1use crate::base::FeatureSelector;
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use sklears_core::{
9 error::{Result as SklResult, SklearsError},
10 traits::{Estimator, Fit, Trained, Transform, Untrained},
11 types::Float,
12};
13use std::collections::{HashMap, VecDeque};
14use std::marker::PhantomData;
15
16#[derive(Debug, Clone)]
21pub struct OnlineFeatureSelector<State = Untrained> {
22 k: usize,
24 window_size: Option<usize>,
25 decay_factor: f64,
26 min_samples: usize,
27
28 feature_means_: Option<Array1<Float>>,
30 feature_vars_: Option<Array1<Float>>,
31 sample_count_: usize,
32 target_correlation_: Option<Array1<Float>>,
33 selected_features_: Option<Vec<usize>>,
34 n_features_: Option<usize>,
35
36 window_data_: Option<VecDeque<Array1<Float>>>,
38 window_targets_: Option<VecDeque<Float>>,
39
40 state: PhantomData<State>,
41}
42
43impl OnlineFeatureSelector<Untrained> {
44 pub fn new(k: usize) -> Self {
46 Self {
47 k,
48 window_size: None,
49 decay_factor: 0.95,
50 min_samples: 10,
51 feature_means_: None,
52 feature_vars_: None,
53 sample_count_: 0,
54 target_correlation_: None,
55 selected_features_: None,
56 n_features_: None,
57 window_data_: None,
58 window_targets_: None,
59 state: PhantomData,
60 }
61 }
62
63 pub fn window_size(mut self, window_size: usize) -> Self {
65 self.window_size = Some(window_size);
66 self
67 }
68
69 pub fn decay_factor(mut self, decay_factor: f64) -> Self {
71 if !(0.0..=1.0).contains(&decay_factor) {
72 panic!("decay_factor must be between 0 and 1");
73 }
74 self.decay_factor = decay_factor;
75 self
76 }
77
78 pub fn min_samples(mut self, min_samples: usize) -> Self {
80 self.min_samples = min_samples;
81 self
82 }
83}
84
85impl Default for OnlineFeatureSelector<Untrained> {
86 fn default() -> Self {
87 Self::new(10)
88 }
89}
90
91impl Estimator for OnlineFeatureSelector<Untrained> {
92 type Config = ();
93 type Error = SklearsError;
94 type Float = f64;
95
96 fn config(&self) -> &Self::Config {
97 &()
98 }
99}
100
101impl Fit<Array2<Float>, Array1<Float>> for OnlineFeatureSelector<Untrained> {
102 type Fitted = OnlineFeatureSelector<Trained>;
103
104 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
105 let (n_samples, n_features) = x.dim();
106 if n_samples == 0 || n_features == 0 {
107 return Err(SklearsError::InvalidInput(
108 "Input data cannot be empty".to_string(),
109 ));
110 }
111
112 if self.k > n_features {
113 return Err(SklearsError::InvalidInput(
114 "k cannot be larger than number of features".to_string(),
115 ));
116 }
117
118 let mut selector = OnlineFeatureSelector {
119 k: self.k,
120 window_size: self.window_size,
121 decay_factor: self.decay_factor,
122 min_samples: self.min_samples,
123 feature_means_: Some(Array1::zeros(n_features)),
124 feature_vars_: Some(Array1::zeros(n_features)),
125 sample_count_: 0,
126 target_correlation_: Some(Array1::zeros(n_features)),
127 selected_features_: Some(Vec::new()),
128 n_features_: Some(n_features),
129 window_data_: if self.window_size.is_some() {
130 Some(VecDeque::new())
131 } else {
132 None
133 },
134 window_targets_: if self.window_size.is_some() {
135 Some(VecDeque::new())
136 } else {
137 None
138 },
139 state: PhantomData,
140 };
141
142 for (sample_idx, target) in x.axis_iter(Axis(0)).zip(y.iter()) {
144 selector.partial_fit_sample(&sample_idx.to_owned(), *target)?;
145 }
146
147 Ok(selector)
148 }
149}
150
151impl OnlineFeatureSelector<Trained> {
152 pub fn partial_fit_sample(&mut self, sample: &Array1<Float>, target: Float) -> SklResult<()> {
154 let n_features = sample.len();
155
156 if let Some(expected_features) = self.n_features_ {
157 if n_features != expected_features {
158 return Err(SklearsError::InvalidInput(
159 "Sample has different number of features than expected".to_string(),
160 ));
161 }
162 } else {
163 self.n_features_ = Some(n_features);
164 self.feature_means_ = Some(Array1::zeros(n_features));
165 self.feature_vars_ = Some(Array1::zeros(n_features));
166 self.target_correlation_ = Some(Array1::zeros(n_features));
167 }
168
169 if let (Some(window_data), Some(window_targets)) =
171 (self.window_data_.as_mut(), self.window_targets_.as_mut())
172 {
173 if let Some(window_size) = self.window_size {
174 window_data.push_back(sample.clone());
175 window_targets.push_back(target);
176
177 if window_data.len() > window_size {
178 window_data.pop_front();
179 window_targets.pop_front();
180 }
181 }
182 }
183
184 if let (Some(means), Some(vars), Some(correlations)) = (
186 self.feature_means_.as_mut(),
187 self.feature_vars_.as_mut(),
188 self.target_correlation_.as_mut(),
189 ) {
190 self.sample_count_ += 1;
191 let alpha = if self.sample_count_ == 1 {
192 1.0
193 } else {
194 1.0 - self.decay_factor
195 };
196
197 for (i, &value) in sample.iter().enumerate() {
198 let old_mean = means[i];
200 means[i] = alpha * value + (1.0 - alpha) * old_mean;
201
202 let delta = value - old_mean;
204 let delta2 = value - means[i];
205 vars[i] = (1.0 - alpha) * vars[i] + alpha * delta * delta2;
206
207 let target_mean = 0.0; let target_centered = target - target_mean;
210 let feature_centered = value - means[i];
211 correlations[i] =
212 alpha * (feature_centered * target_centered) + (1.0 - alpha) * correlations[i];
213 }
214 }
215
216 if self.sample_count_ >= self.min_samples {
218 self.update_feature_selection()?;
219 }
220
221 Ok(())
222 }
223
224 fn update_feature_selection(&mut self) -> SklResult<()> {
226 if let Some(correlations) = &self.target_correlation_ {
227 let mut feature_scores: Vec<(usize, f64)> = correlations
229 .iter()
230 .enumerate()
231 .map(|(i, &corr)| (i, corr.abs()))
232 .collect();
233
234 feature_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
236
237 let selected: Vec<usize> = feature_scores
239 .into_iter()
240 .take(self.k)
241 .map(|(idx, _)| idx)
242 .collect();
243
244 self.selected_features_ = Some(selected);
245 }
246
247 Ok(())
248 }
249
250 fn compute_target_mean(&self) -> Float {
252 0.0
255 }
256
257 pub fn detect_concept_drift(&self) -> SklResult<bool> {
259 if let (Some(window_data), Some(window_targets)) =
260 (&self.window_data_, &self.window_targets_)
261 {
262 if window_data.len() < 20 {
263 return Ok(false); }
265
266 let mid = window_data.len() / 2;
268 let first_half_targets: Vec<Float> = window_targets.iter().take(mid).cloned().collect();
269 let second_half_targets: Vec<Float> =
270 window_targets.iter().skip(mid).cloned().collect();
271
272 let first_mean =
274 first_half_targets.iter().sum::<Float>() / first_half_targets.len() as Float;
275 let second_mean =
276 second_half_targets.iter().sum::<Float>() / second_half_targets.len() as Float;
277
278 let drift_threshold = 0.5;
280 Ok((first_mean - second_mean).abs() > drift_threshold)
281 } else {
282 Ok(false)
283 }
284 }
285
286 pub fn reset(&mut self) -> SklResult<()> {
288 if let Some(n_features) = self.n_features_ {
289 self.feature_means_ = Some(Array1::zeros(n_features));
290 self.feature_vars_ = Some(Array1::zeros(n_features));
291 self.target_correlation_ = Some(Array1::zeros(n_features));
292 self.sample_count_ = 0;
293
294 if let Some(window_data) = self.window_data_.as_mut() {
295 window_data.clear();
296 }
297 if let Some(window_targets) = self.window_targets_.as_mut() {
298 window_targets.clear();
299 }
300 }
301 Ok(())
302 }
303}
304
305impl FeatureSelector for OnlineFeatureSelector<Trained> {
306 fn selected_features(&self) -> &Vec<usize> {
307 match &self.selected_features_ {
308 Some(features) => features,
309 None => {
310 static EMPTY: Vec<usize> = Vec::new();
311 &EMPTY
312 }
313 }
314 }
315}
316
317impl Transform<Array2<Float>, Array2<Float>> for OnlineFeatureSelector<Trained> {
318 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
319 if let Some(selected) = &self.selected_features_ {
320 if selected.is_empty() {
321 return Err(SklearsError::InvalidData {
322 reason: "No features selected yet".to_string(),
323 });
324 }
325
326 let selected_cols = x.select(Axis(1), selected);
327 Ok(selected_cols)
328 } else {
329 Err(SklearsError::InvalidData {
330 reason: "Selector not fitted yet".to_string(),
331 })
332 }
333 }
334}
335
336#[derive(Debug, Clone)]
340pub struct StreamingFeatureImportance {
341 decay_factor: f64,
343 min_samples: usize,
344
345 importance_scores_: HashMap<usize, Float>,
347 sample_count_: usize,
348 n_features_: Option<usize>,
349}
350
351impl StreamingFeatureImportance {
352 pub fn new() -> Self {
354 Self {
355 decay_factor: 0.95,
356 min_samples: 10,
357 importance_scores_: HashMap::new(),
358 sample_count_: 0,
359 n_features_: None,
360 }
361 }
362
363 pub fn decay_factor(mut self, decay_factor: f64) -> Self {
365 if !(0.0..=1.0).contains(&decay_factor) {
366 panic!("decay_factor must be between 0 and 1");
367 }
368 self.decay_factor = decay_factor;
369 self
370 }
371
372 pub fn update(
374 &mut self,
375 features: &Array1<Float>,
376 target: Float,
377 prediction: Float,
378 ) -> SklResult<()> {
379 let n_features = features.len();
380
381 if let Some(expected) = self.n_features_ {
382 if n_features != expected {
383 return Err(SklearsError::InvalidInput(
384 "Inconsistent number of features".to_string(),
385 ));
386 }
387 } else {
388 self.n_features_ = Some(n_features);
389 }
390
391 self.sample_count_ += 1;
392 let prediction_error = (target - prediction).abs();
393
394 for (i, &feature_value) in features.iter().enumerate() {
396 let contribution = feature_value.abs() * prediction_error;
397
398 let current_importance = self.importance_scores_.get(&i).cloned().unwrap_or(0.0);
399 let alpha = 1.0 - self.decay_factor;
400 let new_importance = alpha * contribution + self.decay_factor * current_importance;
401
402 self.importance_scores_.insert(i, new_importance);
403 }
404
405 Ok(())
406 }
407
408 pub fn get_importance_scores(&self) -> &HashMap<usize, Float> {
410 &self.importance_scores_
411 }
412
413 pub fn get_top_features(&self, k: usize) -> Vec<usize> {
415 let mut scores: Vec<(usize, Float)> = self
416 .importance_scores_
417 .iter()
418 .map(|(&idx, &score)| (idx, score))
419 .collect();
420
421 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
422 scores.into_iter().take(k).map(|(idx, _)| idx).collect()
423 }
424}
425
426impl Default for StreamingFeatureImportance {
427 fn default() -> Self {
428 Self::new()
429 }
430}
431
432#[derive(Debug, Clone)]
436pub struct ConceptDriftAwareSelector<State = Untrained> {
437 base_selector: OnlineFeatureSelector<State>,
438 drift_detection_window: usize,
439 drift_threshold: f64,
440 adaptation_rate: f64,
441
442 performance_history_: VecDeque<Float>,
444 drift_detected_: bool,
445}
446
447impl ConceptDriftAwareSelector<Untrained> {
448 pub fn new(k: usize) -> Self {
450 Self {
451 base_selector: OnlineFeatureSelector::new(k),
452 drift_detection_window: 100,
453 drift_threshold: 0.05,
454 adaptation_rate: 0.1,
455 performance_history_: VecDeque::new(),
456 drift_detected_: false,
457 }
458 }
459
460 pub fn drift_detection_window(mut self, window_size: usize) -> Self {
462 self.drift_detection_window = window_size;
463 self
464 }
465
466 pub fn drift_threshold(mut self, threshold: f64) -> Self {
468 self.drift_threshold = threshold;
469 self
470 }
471
472 pub fn min_samples(mut self, min_samples: usize) -> Self {
474 self.base_selector = self.base_selector.min_samples(min_samples);
475 self
476 }
477}
478
479impl Estimator for ConceptDriftAwareSelector<Untrained> {
480 type Config = ();
481 type Error = SklearsError;
482 type Float = f64;
483
484 fn config(&self) -> &Self::Config {
485 &()
486 }
487}
488
489impl Fit<Array2<Float>, Array1<Float>> for ConceptDriftAwareSelector<Untrained> {
490 type Fitted = ConceptDriftAwareSelector<Trained>;
491
492 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
493 let fitted_base = self.base_selector.fit(x, y)?;
494
495 Ok(ConceptDriftAwareSelector {
496 base_selector: fitted_base,
497 drift_detection_window: self.drift_detection_window,
498 drift_threshold: self.drift_threshold,
499 adaptation_rate: self.adaptation_rate,
500 performance_history_: VecDeque::new(),
501 drift_detected_: false,
502 })
503 }
504}
505
506impl ConceptDriftAwareSelector<Trained> {
507 pub fn partial_fit_with_performance(
509 &mut self,
510 sample: &Array1<Float>,
511 target: Float,
512 performance: Float,
513 ) -> SklResult<()> {
514 self.base_selector.partial_fit_sample(sample, target)?;
516
517 self.performance_history_.push_back(performance);
519 if self.performance_history_.len() > self.drift_detection_window {
520 self.performance_history_.pop_front();
521 }
522
523 if self.performance_history_.len() >= self.drift_detection_window / 2 {
525 self.drift_detected_ = self.detect_performance_drift()?;
526
527 if self.drift_detected_ {
528 self.adapt_to_drift()?;
530 }
531 }
532
533 Ok(())
534 }
535
536 fn detect_performance_drift(&self) -> SklResult<bool> {
538 if self.performance_history_.len() < 20 {
539 return Ok(false);
540 }
541
542 let mid = self.performance_history_.len() / 2;
543 let recent_perf: Float = self.performance_history_.iter().skip(mid).sum::<Float>()
544 / (self.performance_history_.len() - mid) as Float;
545
546 let old_perf: Float =
547 self.performance_history_.iter().take(mid).sum::<Float>() / mid as Float;
548
549 Ok(old_perf - recent_perf > self.drift_threshold)
551 }
552
553 fn adapt_to_drift(&mut self) -> SklResult<()> {
555 let reset_fraction = self.adaptation_rate;
560 let samples_to_keep =
561 ((1.0 - reset_fraction) * self.performance_history_.len() as f64) as usize;
562
563 while self.performance_history_.len() > samples_to_keep {
564 self.performance_history_.pop_front();
565 }
566
567 self.drift_detected_ = false;
568 Ok(())
569 }
570
571 pub fn drift_detected(&self) -> bool {
573 self.drift_detected_
574 }
575}
576
577impl FeatureSelector for ConceptDriftAwareSelector<Trained> {
578 fn selected_features(&self) -> &Vec<usize> {
579 self.base_selector.selected_features()
580 }
581}
582
583impl Transform<Array2<Float>, Array2<Float>> for ConceptDriftAwareSelector<Trained> {
584 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
585 self.base_selector.transform(x)
586 }
587}
588
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    use scirs2_core::ndarray::array;

    #[test]
    fn test_online_feature_selector_basic() {
        let features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let targets = array![1.0, 2.0, 3.0];

        let model = OnlineFeatureSelector::new(2)
            .min_samples(2)
            .fit(&features, &targets)
            .unwrap();

        assert_eq!(model.selected_features().len(), 2);
        assert_eq!(model.sample_count_, 3);
    }

    #[test]
    fn test_online_selector_partial_fit() {
        let features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let targets = array![1.0, 2.0];

        let mut model = OnlineFeatureSelector::new(2)
            .fit(&features, &targets)
            .unwrap();

        // One extra streamed sample should bump the counter past the fitted rows.
        let extra = array![10.0, 11.0, 12.0];
        model.partial_fit_sample(&extra, 3.0).unwrap();

        assert_eq!(model.sample_count_, 3);
    }

    #[test]
    fn test_streaming_importance() {
        let mut tracker = StreamingFeatureImportance::new();

        let observation = array![1.0, 2.0, 3.0];
        tracker.update(&observation, 5.0, 4.8).unwrap();

        assert_eq!(tracker.get_importance_scores().len(), 3);
        assert_eq!(tracker.get_top_features(2).len(), 2);
    }

    #[test]
    fn test_concept_drift_selector() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let targets = array![1.0, 2.0, 3.0];

        let mut model = ConceptDriftAwareSelector::new(1)
            .min_samples(2)
            .fit(&features, &targets)
            .unwrap();

        let extra = array![7.0, 8.0];
        model
            .partial_fit_with_performance(&extra, 4.0, 0.9)
            .unwrap();

        assert_eq!(model.selected_features().len(), 1);
    }

    #[test]
    fn test_online_selector_transform() {
        let train_x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let train_y = array![1.0, 2.0];

        let model = OnlineFeatureSelector::new(2)
            .min_samples(2)
            .fit(&train_x, &train_y)
            .unwrap();

        let unseen = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
        let projected = model.transform(&unseen).unwrap();

        assert_eq!(projected.ncols(), 2);
        assert_eq!(projected.nrows(), 2);
    }
}