sklears_feature_selection/
streaming.rs

//! Streaming and online feature selection methods.
//!
//! This module provides feature-selection algorithms that process data incrementally,
//! handle concept drift, and adapt their selection over time.
5
6use crate::base::FeatureSelector;
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use sklears_core::{
9    error::{Result as SklResult, SklearsError},
10    traits::{Estimator, Fit, Trained, Transform, Untrained},
11    types::Float,
12};
13use std::collections::{HashMap, VecDeque};
14use std::marker::PhantomData;
15
/// Online feature selection using incremental statistics
///
/// This selector maintains running statistics for each feature and updates
/// feature selection based on these statistics as new data arrives.
///
/// `State` is a marker type ([`Untrained`] or [`Trained`]) that decides which
/// methods are available; it is only stored as `PhantomData`.
#[derive(Debug, Clone)]
pub struct OnlineFeatureSelector<State = Untrained> {
    // Configuration
    /// Number of features to keep.
    k: usize,
    /// Length of the sliding window used for concept-drift detection;
    /// `None` disables the window.
    window_size: Option<usize>,
    /// Weight kept for the previous value in each exponential update
    /// (a new sample gets weight `1 - decay_factor`, except the very first one).
    decay_factor: f64,
    /// Number of processed samples required before a selection is computed.
    min_samples: usize,

    // Online statistics
    /// Exponentially weighted per-feature means.
    feature_means_: Option<Array1<Float>>,
    /// Exponentially weighted per-feature variances.
    feature_vars_: Option<Array1<Float>>,
    /// Samples processed since fitting started or since the last `reset`.
    sample_count_: usize,
    /// Exponentially weighted product of the centred feature and the raw
    /// (uncentred) target, per feature.
    target_correlation_: Option<Array1<Float>>,
    /// Indices of the currently selected features.
    selected_features_: Option<Vec<usize>>,
    /// Number of features every sample must have.
    n_features_: Option<usize>,

    // Sliding window for concept drift detection
    /// The most recent samples, at most `window_size` of them.
    window_data_: Option<VecDeque<Array1<Float>>>,
    /// Targets matching `window_data_`, in the same order.
    window_targets_: Option<VecDeque<Float>>,

    state: PhantomData<State>,
}
42
43impl OnlineFeatureSelector<Untrained> {
44    /// Create a new online feature selector
45    pub fn new(k: usize) -> Self {
46        Self {
47            k,
48            window_size: None,
49            decay_factor: 0.95,
50            min_samples: 10,
51            feature_means_: None,
52            feature_vars_: None,
53            sample_count_: 0,
54            target_correlation_: None,
55            selected_features_: None,
56            n_features_: None,
57            window_data_: None,
58            window_targets_: None,
59            state: PhantomData,
60        }
61    }
62
63    /// Set the window size for concept drift detection
64    pub fn window_size(mut self, window_size: usize) -> Self {
65        self.window_size = Some(window_size);
66        self
67    }
68
69    /// Set the decay factor for exponential moving statistics
70    pub fn decay_factor(mut self, decay_factor: f64) -> Self {
71        if !(0.0..=1.0).contains(&decay_factor) {
72            panic!("decay_factor must be between 0 and 1");
73        }
74        self.decay_factor = decay_factor;
75        self
76    }
77
78    /// Set minimum samples before selection begins
79    pub fn min_samples(mut self, min_samples: usize) -> Self {
80        self.min_samples = min_samples;
81        self
82    }
83}
84
85impl Default for OnlineFeatureSelector<Untrained> {
86    fn default() -> Self {
87        Self::new(10)
88    }
89}
90
91impl Estimator for OnlineFeatureSelector<Untrained> {
92    type Config = ();
93    type Error = SklearsError;
94    type Float = f64;
95
96    fn config(&self) -> &Self::Config {
97        &()
98    }
99}
100
101impl Fit<Array2<Float>, Array1<Float>> for OnlineFeatureSelector<Untrained> {
102    type Fitted = OnlineFeatureSelector<Trained>;
103
104    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
105        let (n_samples, n_features) = x.dim();
106        if n_samples == 0 || n_features == 0 {
107            return Err(SklearsError::InvalidInput(
108                "Input data cannot be empty".to_string(),
109            ));
110        }
111
112        if self.k > n_features {
113            return Err(SklearsError::InvalidInput(
114                "k cannot be larger than number of features".to_string(),
115            ));
116        }
117
118        let mut selector = OnlineFeatureSelector {
119            k: self.k,
120            window_size: self.window_size,
121            decay_factor: self.decay_factor,
122            min_samples: self.min_samples,
123            feature_means_: Some(Array1::zeros(n_features)),
124            feature_vars_: Some(Array1::zeros(n_features)),
125            sample_count_: 0,
126            target_correlation_: Some(Array1::zeros(n_features)),
127            selected_features_: Some(Vec::new()),
128            n_features_: Some(n_features),
129            window_data_: if self.window_size.is_some() {
130                Some(VecDeque::new())
131            } else {
132                None
133            },
134            window_targets_: if self.window_size.is_some() {
135                Some(VecDeque::new())
136            } else {
137                None
138            },
139            state: PhantomData,
140        };
141
142        // Process initial data
143        for (sample_idx, target) in x.axis_iter(Axis(0)).zip(y.iter()) {
144            selector.partial_fit_sample(&sample_idx.to_owned(), *target)?;
145        }
146
147        Ok(selector)
148    }
149}
150
151impl OnlineFeatureSelector<Trained> {
152    /// Update the selector with a new sample
153    pub fn partial_fit_sample(&mut self, sample: &Array1<Float>, target: Float) -> SklResult<()> {
154        let n_features = sample.len();
155
156        if let Some(expected_features) = self.n_features_ {
157            if n_features != expected_features {
158                return Err(SklearsError::InvalidInput(
159                    "Sample has different number of features than expected".to_string(),
160                ));
161            }
162        } else {
163            self.n_features_ = Some(n_features);
164            self.feature_means_ = Some(Array1::zeros(n_features));
165            self.feature_vars_ = Some(Array1::zeros(n_features));
166            self.target_correlation_ = Some(Array1::zeros(n_features));
167        }
168
169        // Update sliding window if enabled
170        if let (Some(window_data), Some(window_targets)) =
171            (self.window_data_.as_mut(), self.window_targets_.as_mut())
172        {
173            if let Some(window_size) = self.window_size {
174                window_data.push_back(sample.clone());
175                window_targets.push_back(target);
176
177                if window_data.len() > window_size {
178                    window_data.pop_front();
179                    window_targets.pop_front();
180                }
181            }
182        }
183
184        // Update exponential moving statistics
185        if let (Some(means), Some(vars), Some(correlations)) = (
186            self.feature_means_.as_mut(),
187            self.feature_vars_.as_mut(),
188            self.target_correlation_.as_mut(),
189        ) {
190            self.sample_count_ += 1;
191            let alpha = if self.sample_count_ == 1 {
192                1.0
193            } else {
194                1.0 - self.decay_factor
195            };
196
197            for (i, &value) in sample.iter().enumerate() {
198                // Update mean
199                let old_mean = means[i];
200                means[i] = alpha * value + (1.0 - alpha) * old_mean;
201
202                // Update variance (using Welford's online algorithm)
203                let delta = value - old_mean;
204                let delta2 = value - means[i];
205                vars[i] = (1.0 - alpha) * vars[i] + alpha * delta * delta2;
206
207                // Update correlation with target (simplified)
208                let target_mean = 0.0; // Simplified - in practice maintain running target mean
209                let target_centered = target - target_mean;
210                let feature_centered = value - means[i];
211                correlations[i] =
212                    alpha * (feature_centered * target_centered) + (1.0 - alpha) * correlations[i];
213            }
214        }
215
216        // Update feature selection if we have enough samples
217        if self.sample_count_ >= self.min_samples {
218            self.update_feature_selection()?;
219        }
220
221        Ok(())
222    }
223
224    /// Update feature selection based on current statistics
225    fn update_feature_selection(&mut self) -> SklResult<()> {
226        if let Some(correlations) = &self.target_correlation_ {
227            // Select features with highest absolute correlation
228            let mut feature_scores: Vec<(usize, f64)> = correlations
229                .iter()
230                .enumerate()
231                .map(|(i, &corr)| (i, corr.abs()))
232                .collect();
233
234            // Sort by score (descending)
235            feature_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
236
237            // Select top k features
238            let selected: Vec<usize> = feature_scores
239                .into_iter()
240                .take(self.k)
241                .map(|(idx, _)| idx)
242                .collect();
243
244            self.selected_features_ = Some(selected);
245        }
246
247        Ok(())
248    }
249
250    /// Compute current target mean (simplified for this example)
251    fn compute_target_mean(&self) -> Float {
252        // This is a simplified implementation
253        // In practice, you'd maintain a running mean of targets
254        0.0
255    }
256
257    /// Detect concept drift using window statistics
258    pub fn detect_concept_drift(&self) -> SklResult<bool> {
259        if let (Some(window_data), Some(window_targets)) =
260            (&self.window_data_, &self.window_targets_)
261        {
262            if window_data.len() < 20 {
263                return Ok(false); // Not enough data
264            }
265
266            // Simple drift detection: compare first and second half of window
267            let mid = window_data.len() / 2;
268            let first_half_targets: Vec<Float> = window_targets.iter().take(mid).cloned().collect();
269            let second_half_targets: Vec<Float> =
270                window_targets.iter().skip(mid).cloned().collect();
271
272            // Compute means
273            let first_mean =
274                first_half_targets.iter().sum::<Float>() / first_half_targets.len() as Float;
275            let second_mean =
276                second_half_targets.iter().sum::<Float>() / second_half_targets.len() as Float;
277
278            // Simple threshold-based drift detection
279            let drift_threshold = 0.5;
280            Ok((first_mean - second_mean).abs() > drift_threshold)
281        } else {
282            Ok(false)
283        }
284    }
285
286    /// Reset the selector (useful when concept drift is detected)
287    pub fn reset(&mut self) -> SklResult<()> {
288        if let Some(n_features) = self.n_features_ {
289            self.feature_means_ = Some(Array1::zeros(n_features));
290            self.feature_vars_ = Some(Array1::zeros(n_features));
291            self.target_correlation_ = Some(Array1::zeros(n_features));
292            self.sample_count_ = 0;
293
294            if let Some(window_data) = self.window_data_.as_mut() {
295                window_data.clear();
296            }
297            if let Some(window_targets) = self.window_targets_.as_mut() {
298                window_targets.clear();
299            }
300        }
301        Ok(())
302    }
303}
304
305impl FeatureSelector for OnlineFeatureSelector<Trained> {
306    fn selected_features(&self) -> &Vec<usize> {
307        match &self.selected_features_ {
308            Some(features) => features,
309            None => {
310                static EMPTY: Vec<usize> = Vec::new();
311                &EMPTY
312            }
313        }
314    }
315}
316
317impl Transform<Array2<Float>, Array2<Float>> for OnlineFeatureSelector<Trained> {
318    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
319        if let Some(selected) = &self.selected_features_ {
320            if selected.is_empty() {
321                return Err(SklearsError::InvalidData {
322                    reason: "No features selected yet".to_string(),
323                });
324            }
325
326            let selected_cols = x.select(Axis(1), selected);
327            Ok(selected_cols)
328        } else {
329            Err(SklearsError::InvalidData {
330                reason: "Selector not fitted yet".to_string(),
331            })
332        }
333    }
334}
335
/// Streaming feature importance calculator
///
/// Maintains running importance scores for features in a streaming fashion.
/// Each score is an exponential moving average of
/// `|feature value| * |target - prediction|`, keyed by feature index.
#[derive(Debug, Clone)]
pub struct StreamingFeatureImportance {
    // Configuration
    /// Weight kept for a feature's previous score on every update.
    decay_factor: f64,
    /// NOTE(review): set to 10 in `new` but never read in this file and has no
    /// setter. Confirm whether it is meant to gate reporting.
    min_samples: usize,

    // State
    /// Feature index -> current exponentially weighted importance score.
    importance_scores_: HashMap<usize, Float>,
    /// Number of updates accepted so far.
    sample_count_: usize,
    /// Feature count fixed by the first update; later updates must match it.
    n_features_: Option<usize>,
}
350
351impl StreamingFeatureImportance {
352    /// Create a new streaming feature importance calculator
353    pub fn new() -> Self {
354        Self {
355            decay_factor: 0.95,
356            min_samples: 10,
357            importance_scores_: HashMap::new(),
358            sample_count_: 0,
359            n_features_: None,
360        }
361    }
362
363    /// Set decay factor for exponential moving average
364    pub fn decay_factor(mut self, decay_factor: f64) -> Self {
365        if !(0.0..=1.0).contains(&decay_factor) {
366            panic!("decay_factor must be between 0 and 1");
367        }
368        self.decay_factor = decay_factor;
369        self
370    }
371
372    /// Update importance scores with new sample
373    pub fn update(
374        &mut self,
375        features: &Array1<Float>,
376        target: Float,
377        prediction: Float,
378    ) -> SklResult<()> {
379        let n_features = features.len();
380
381        if let Some(expected) = self.n_features_ {
382            if n_features != expected {
383                return Err(SklearsError::InvalidInput(
384                    "Inconsistent number of features".to_string(),
385                ));
386            }
387        } else {
388            self.n_features_ = Some(n_features);
389        }
390
391        self.sample_count_ += 1;
392        let prediction_error = (target - prediction).abs();
393
394        // Update importance based on feature values and prediction error
395        for (i, &feature_value) in features.iter().enumerate() {
396            let contribution = feature_value.abs() * prediction_error;
397
398            let current_importance = self.importance_scores_.get(&i).cloned().unwrap_or(0.0);
399            let alpha = 1.0 - self.decay_factor;
400            let new_importance = alpha * contribution + self.decay_factor * current_importance;
401
402            self.importance_scores_.insert(i, new_importance);
403        }
404
405        Ok(())
406    }
407
408    /// Get current importance scores
409    pub fn get_importance_scores(&self) -> &HashMap<usize, Float> {
410        &self.importance_scores_
411    }
412
413    /// Get top k most important features
414    pub fn get_top_features(&self, k: usize) -> Vec<usize> {
415        let mut scores: Vec<(usize, Float)> = self
416            .importance_scores_
417            .iter()
418            .map(|(&idx, &score)| (idx, score))
419            .collect();
420
421        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
422        scores.into_iter().take(k).map(|(idx, _)| idx).collect()
423    }
424}
425
426impl Default for StreamingFeatureImportance {
427    fn default() -> Self {
428        Self::new()
429    }
430}
431
/// Concept drift-aware feature selector
///
/// Delegates feature selection to a wrapped [`OnlineFeatureSelector`] and
/// watches a caller-supplied per-sample performance metric to decide when
/// concept drift has occurred.
#[derive(Debug, Clone)]
pub struct ConceptDriftAwareSelector<State = Untrained> {
    /// Selector that keeps the statistics and the current feature selection.
    base_selector: OnlineFeatureSelector<State>,
    /// Maximum number of performance values kept in the history.
    drift_detection_window: usize,
    /// Drop in mean performance (older half minus newer half of the history)
    /// above which drift is reported; this presumes higher performance is better.
    drift_threshold: f64,
    /// Fraction of the performance history, oldest first, discarded when drift
    /// is detected.
    adaptation_rate: f64,

    // Drift detection state
    /// Most recent performance values, oldest first.
    performance_history_: VecDeque<Float>,
    /// Flag reported by `drift_detected`; written during
    /// `partial_fit_with_performance`.
    drift_detected_: bool,
}
446
447impl ConceptDriftAwareSelector<Untrained> {
448    /// Create a new concept drift-aware selector
449    pub fn new(k: usize) -> Self {
450        Self {
451            base_selector: OnlineFeatureSelector::new(k),
452            drift_detection_window: 100,
453            drift_threshold: 0.05,
454            adaptation_rate: 0.1,
455            performance_history_: VecDeque::new(),
456            drift_detected_: false,
457        }
458    }
459
460    /// Set drift detection window size
461    pub fn drift_detection_window(mut self, window_size: usize) -> Self {
462        self.drift_detection_window = window_size;
463        self
464    }
465
466    /// Set drift detection threshold
467    pub fn drift_threshold(mut self, threshold: f64) -> Self {
468        self.drift_threshold = threshold;
469        self
470    }
471
472    /// Set minimum samples for feature selection
473    pub fn min_samples(mut self, min_samples: usize) -> Self {
474        self.base_selector = self.base_selector.min_samples(min_samples);
475        self
476    }
477}
478
479impl Estimator for ConceptDriftAwareSelector<Untrained> {
480    type Config = ();
481    type Error = SklearsError;
482    type Float = f64;
483
484    fn config(&self) -> &Self::Config {
485        &()
486    }
487}
488
489impl Fit<Array2<Float>, Array1<Float>> for ConceptDriftAwareSelector<Untrained> {
490    type Fitted = ConceptDriftAwareSelector<Trained>;
491
492    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
493        let fitted_base = self.base_selector.fit(x, y)?;
494
495        Ok(ConceptDriftAwareSelector {
496            base_selector: fitted_base,
497            drift_detection_window: self.drift_detection_window,
498            drift_threshold: self.drift_threshold,
499            adaptation_rate: self.adaptation_rate,
500            performance_history_: VecDeque::new(),
501            drift_detected_: false,
502        })
503    }
504}
505
506impl ConceptDriftAwareSelector<Trained> {
507    /// Update with new sample and check for drift
508    pub fn partial_fit_with_performance(
509        &mut self,
510        sample: &Array1<Float>,
511        target: Float,
512        performance: Float,
513    ) -> SklResult<()> {
514        // Update base selector
515        self.base_selector.partial_fit_sample(sample, target)?;
516
517        // Update performance history
518        self.performance_history_.push_back(performance);
519        if self.performance_history_.len() > self.drift_detection_window {
520            self.performance_history_.pop_front();
521        }
522
523        // Check for drift
524        if self.performance_history_.len() >= self.drift_detection_window / 2 {
525            self.drift_detected_ = self.detect_performance_drift()?;
526
527            if self.drift_detected_ {
528                // Adapt to drift by partially resetting the selector
529                self.adapt_to_drift()?;
530            }
531        }
532
533        Ok(())
534    }
535
536    /// Detect drift based on performance degradation
537    fn detect_performance_drift(&self) -> SklResult<bool> {
538        if self.performance_history_.len() < 20 {
539            return Ok(false);
540        }
541
542        let mid = self.performance_history_.len() / 2;
543        let recent_perf: Float = self.performance_history_.iter().skip(mid).sum::<Float>()
544            / (self.performance_history_.len() - mid) as Float;
545
546        let old_perf: Float =
547            self.performance_history_.iter().take(mid).sum::<Float>() / mid as Float;
548
549        // Drift detected if recent performance is significantly worse
550        Ok(old_perf - recent_perf > self.drift_threshold)
551    }
552
553    /// Adapt to detected concept drift
554    fn adapt_to_drift(&mut self) -> SklResult<()> {
555        // Partially reset the base selector to adapt to new concept
556        // This is a simplified adaptation strategy
557
558        // Clear some of the history to focus on recent data
559        let reset_fraction = self.adaptation_rate;
560        let samples_to_keep =
561            ((1.0 - reset_fraction) * self.performance_history_.len() as f64) as usize;
562
563        while self.performance_history_.len() > samples_to_keep {
564            self.performance_history_.pop_front();
565        }
566
567        self.drift_detected_ = false;
568        Ok(())
569    }
570
571    /// Check if drift was recently detected
572    pub fn drift_detected(&self) -> bool {
573        self.drift_detected_
574    }
575}
576
577impl FeatureSelector for ConceptDriftAwareSelector<Trained> {
578    fn selected_features(&self) -> &Vec<usize> {
579        self.base_selector.selected_features()
580    }
581}
582
583impl Transform<Array2<Float>, Array2<Float>> for ConceptDriftAwareSelector<Trained> {
584    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
585        self.base_selector.transform(x)
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    use scirs2_core::ndarray::array;

    #[test]
    fn test_online_feature_selector_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![1.0, 2.0, 3.0];

        let selector = OnlineFeatureSelector::new(2).min_samples(2);
        let fitted = selector.fit(&x, &y).unwrap();

        assert_eq!(fitted.selected_features().len(), 2);
        assert_eq!(fitted.sample_count_, 3);
    }

    #[test]
    fn test_online_selector_partial_fit() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![1.0, 2.0];

        let selector = OnlineFeatureSelector::new(2);
        let mut fitted = selector.fit(&x, &y).unwrap();

        // Add new sample
        let new_sample = array![10.0, 11.0, 12.0];
        fitted.partial_fit_sample(&new_sample, 3.0).unwrap();

        assert_eq!(fitted.sample_count_, 3);
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x = Array2::<Float>::zeros((0, 3));
        let y = Array1::<Float>::zeros(0);
        assert!(OnlineFeatureSelector::new(1).fit(&x, &y).is_err());
    }

    #[test]
    fn test_fit_rejects_k_larger_than_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 2.0];
        assert!(OnlineFeatureSelector::new(3).fit(&x, &y).is_err());
    }

    #[test]
    fn test_partial_fit_rejects_wrong_feature_count() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![1.0, 2.0];
        let mut fitted = OnlineFeatureSelector::new(2).fit(&x, &y).unwrap();

        assert!(fitted.partial_fit_sample(&array![1.0, 2.0], 1.0).is_err());
        // A rejected sample must not be counted.
        assert_eq!(fitted.sample_count_, 2);
    }

    #[test]
    fn test_transform_before_min_samples_errors() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![1.0, 2.0];
        // Default min_samples (10) is not reached with two samples.
        let fitted = OnlineFeatureSelector::new(2).fit(&x, &y).unwrap();

        assert!(fitted.selected_features().is_empty());
        assert!(fitted.transform(&x).is_err());
    }

    #[test]
    fn test_window_is_bounded() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = OnlineFeatureSelector::new(1)
            .window_size(2)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.window_data_.as_ref().map(|w| w.len()), Some(2));
        assert_eq!(fitted.window_targets_.as_ref().map(|w| w.len()), Some(2));
        // Far too few windowed samples for a drift decision.
        assert!(!fitted.detect_concept_drift().unwrap());
    }

    #[test]
    fn test_reset_clears_statistics_and_window() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];
        let mut fitted = OnlineFeatureSelector::new(1)
            .window_size(5)
            .fit(&x, &y)
            .unwrap();

        fitted.reset().unwrap();

        assert_eq!(fitted.sample_count_, 0);
        assert_eq!(fitted.window_data_.as_ref().map(|w| w.len()), Some(0));
        assert_eq!(fitted.window_targets_.as_ref().map(|w| w.len()), Some(0));
    }

    #[test]
    #[should_panic(expected = "decay_factor must be between 0 and 1")]
    fn test_decay_factor_out_of_range_panics() {
        let _ = OnlineFeatureSelector::new(1).decay_factor(1.5);
    }

    #[test]
    fn test_streaming_importance() {
        let mut importance = StreamingFeatureImportance::new();

        let features = array![1.0, 2.0, 3.0];
        importance.update(&features, 5.0, 4.8).unwrap();

        let scores = importance.get_importance_scores();
        assert_eq!(scores.len(), 3);

        let top_features = importance.get_top_features(2);
        assert_eq!(top_features.len(), 2);
    }

    #[test]
    fn test_streaming_importance_ranks_by_magnitude() {
        let mut importance = StreamingFeatureImportance::new();
        importance.update(&array![1.0, -5.0, 3.0], 2.0, 1.0).unwrap();

        assert_eq!(importance.get_top_features(3), vec![1, 2, 0]);
    }

    #[test]
    fn test_streaming_importance_rejects_inconsistent_features() {
        let mut importance = StreamingFeatureImportance::new();
        importance.update(&array![1.0, 2.0, 3.0], 1.0, 0.5).unwrap();

        assert!(importance.update(&array![1.0, 2.0], 1.0, 0.5).is_err());
    }

    #[test]
    fn test_concept_drift_selector() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let selector = ConceptDriftAwareSelector::new(1).min_samples(2);
        let mut fitted = selector.fit(&x, &y).unwrap();

        // Add sample with performance
        let sample = array![7.0, 8.0];
        fitted
            .partial_fit_with_performance(&sample, 4.0, 0.9)
            .unwrap();

        assert_eq!(fitted.selected_features().len(), 1);
    }

    #[test]
    fn test_no_drift_with_stable_performance() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];
        let mut fitted = ConceptDriftAwareSelector::new(1)
            .drift_detection_window(40)
            .min_samples(2)
            .fit(&x, &y)
            .unwrap();

        for i in 0..40 {
            let v = i as Float;
            fitted
                .partial_fit_with_performance(&array![v, v + 1.0], v, 0.9)
                .unwrap();
        }

        assert!(!fitted.drift_detected());
    }

    #[test]
    fn test_online_selector_transform() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let y = array![1.0, 2.0];

        let selector = OnlineFeatureSelector::new(2).min_samples(2);
        let fitted = selector.fit(&x, &y).unwrap();

        let test_x = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
        let transformed = fitted.transform(&test_x).unwrap();

        assert_eq!(transformed.ncols(), 2);
        assert_eq!(transformed.nrows(), 2);
    }
}