sklears_compose/ensemble/
dynamic_selection.rs

1//! Dynamic Ensemble Selection
2//!
3//! This module provides dynamic ensemble selection strategies that adaptively
4//! choose the best subset of estimators for each prediction based on local competence.
5
6use crate::PipelinePredictor;
7use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
8use sklears_core::{
9    error::Result as SklResult,
10    prelude::SklearsError,
11    traits::{Estimator, Fit, Untrained},
12    types::Float,
13};
14
/// Rule used to decide which fitted estimators contribute to a prediction.
///
/// Each variant carries only the parameter that its rule needs.
#[derive(Clone, Debug)]
pub enum SelectionStrategy {
    /// Keep the `k` estimators with the highest competence score.
    KBest { k: usize },
    /// Keep every estimator whose competence score is at least `threshold`.
    Threshold { threshold: f64 },
    /// Keep every estimator whose competence score is at or above the median.
    AboveMedian,
    /// Keep estimators based on their competence among the `k_neighbors`
    /// nearest validation samples.
    LocalCompetence { k_neighbors: usize },
}
27
/// How a per-sample competence score is derived from an estimator's
/// predictions on the validation set.
#[derive(Clone, Debug)]
pub enum CompetenceEstimation {
    /// Score by agreement with (or closeness to) the true target.
    LocalAccuracy,
    /// Score by distance of the prediction from the decision boundary.
    DecisionBoundary,
    /// Score by the entropy of the prediction (lower entropy scores higher).
    Entropy,
    /// Score by the prediction margin.
    Margin,
}
40
41/// Dynamic Ensemble Selector
42///
43/// Dynamically selects the best subset of classifiers for each prediction
44/// based on local competence estimation.
45///
46/// # Examples
47///
48/// ```ignore
49/// use sklears_compose::{DynamicEnsembleSelector, MockPredictor, SelectionStrategy};
50/// use scirs2_core::ndarray::array;
51///
52/// let selector = DynamicEnsembleSelector::builder()
53///     .estimator("clf1", Box::new(MockPredictor::new()))
54///     .estimator("clf2", Box::new(MockPredictor::new()))
55///     .selection_strategy(SelectionStrategy::KBest { k: 2 })
56///     .build();
57/// ```
58pub struct DynamicEnsembleSelector<S = Untrained> {
59    state: S,
60    estimators: Vec<(String, Box<dyn PipelinePredictor>)>,
61    selection_strategy: SelectionStrategy,
62    competence_estimation: CompetenceEstimation,
63    validation_split: f64,
64    n_jobs: Option<i32>,
65}
66
67/// Trained state for `DynamicEnsembleSelector`
68pub struct DynamicEnsembleSelectorTrained {
69    fitted_estimators: Vec<(String, Box<dyn PipelinePredictor>)>,
70    validation_data: Array2<f64>,
71    validation_targets: Array1<f64>,
72    competence_scores: Vec<Vec<f64>>, // competence for each estimator on validation set
73    n_features_in: usize,
74    feature_names_in: Option<Vec<String>>,
75}
76
77impl DynamicEnsembleSelector<Untrained> {
78    /// Create a new `DynamicEnsembleSelector`
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            state: Untrained,
83            estimators: Vec::new(),
84            selection_strategy: SelectionStrategy::KBest { k: 3 },
85            competence_estimation: CompetenceEstimation::LocalAccuracy,
86            validation_split: 0.2,
87            n_jobs: None,
88        }
89    }
90
91    /// Create a dynamic ensemble selector builder
92    #[must_use]
93    pub fn builder() -> DynamicEnsembleSelectorBuilder {
94        DynamicEnsembleSelectorBuilder::new()
95    }
96
97    /// Add an estimator
98    pub fn add_estimator(&mut self, name: String, estimator: Box<dyn PipelinePredictor>) {
99        self.estimators.push((name, estimator));
100    }
101
102    /// Set selection strategy
103    #[must_use]
104    pub fn selection_strategy(mut self, strategy: SelectionStrategy) -> Self {
105        self.selection_strategy = strategy;
106        self
107    }
108
109    /// Set competence estimation method
110    #[must_use]
111    pub fn competence_estimation(mut self, method: CompetenceEstimation) -> Self {
112        self.competence_estimation = method;
113        self
114    }
115
116    /// Set validation split ratio
117    #[must_use]
118    pub fn validation_split(mut self, split: f64) -> Self {
119        self.validation_split = split.clamp(0.1, 0.5);
120        self
121    }
122
123    /// Set number of jobs
124    #[must_use]
125    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
126        self.n_jobs = n_jobs;
127        self
128    }
129}
130
131impl Default for DynamicEnsembleSelector<Untrained> {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl Estimator for DynamicEnsembleSelector<Untrained> {
138    type Config = ();
139    type Error = SklearsError;
140    type Float = Float;
141
142    fn config(&self) -> &Self::Config {
143        &()
144    }
145}
146
147impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
148    for DynamicEnsembleSelector<Untrained>
149{
150    type Fitted = DynamicEnsembleSelector<DynamicEnsembleSelectorTrained>;
151
152    fn fit(
153        self,
154        x: &ArrayView2<'_, Float>,
155        y: &Option<&ArrayView1<'_, Float>>,
156    ) -> SklResult<Self::Fitted> {
157        if let Some(y_values) = y.as_ref() {
158            if self.estimators.is_empty() {
159                return Err(SklearsError::InvalidInput(
160                    "At least one estimator must be provided".to_string(),
161                ));
162            }
163
164            let n_samples = x.nrows();
165            let validation_size = (n_samples as f64 * self.validation_split).max(1.0) as usize;
166            let train_size = n_samples - validation_size;
167
168            // Split data for training and validation
169            let x_train = x.slice(s![..train_size, ..]);
170            let y_train = y_values.slice(s![..train_size]);
171            let x_val = x.slice(s![train_size.., ..]);
172            let y_val = y_values.slice(s![train_size..]);
173
174            // Train all estimators on training set
175            let mut fitted_estimators = Vec::new();
176            let estimators: Vec<(String, Box<dyn PipelinePredictor>)> = self
177                .estimators
178                .iter()
179                .map(|(name, estimator)| (name.clone(), estimator.clone_predictor()))
180                .collect();
181            for (name, mut estimator) in estimators {
182                estimator.fit(&x_train, &y_train)?;
183                fitted_estimators.push((name, estimator));
184            }
185
186            // Compute competence scores on validation set
187            let competence_scores =
188                self.compute_competence_scores(&fitted_estimators, &x_val, &y_val)?;
189
190            Ok(DynamicEnsembleSelector {
191                state: DynamicEnsembleSelectorTrained {
192                    fitted_estimators,
193                    validation_data: x_val.mapv(|v| v),
194                    validation_targets: y_val.mapv(|v| v),
195                    competence_scores,
196                    n_features_in: x.ncols(),
197                    feature_names_in: None,
198                },
199                estimators: Vec::new(),
200                selection_strategy: self.selection_strategy,
201                competence_estimation: self.competence_estimation,
202                validation_split: self.validation_split,
203                n_jobs: self.n_jobs,
204            })
205        } else {
206            Err(SklearsError::InvalidInput(
207                "Target values required for fitting".to_string(),
208            ))
209        }
210    }
211}
212
213impl DynamicEnsembleSelector<Untrained> {
214    /// Compute competence scores for all estimators on validation set
215    fn compute_competence_scores(
216        &self,
217        estimators: &[(String, Box<dyn PipelinePredictor>)],
218        x_val: &ArrayView2<'_, Float>,
219        y_val: &ArrayView1<'_, Float>,
220    ) -> SklResult<Vec<Vec<f64>>> {
221        let mut competence_scores = Vec::new();
222
223        for (_, estimator) in estimators {
224            let predictions = estimator.predict(x_val)?;
225            let scores = match self.competence_estimation {
226                CompetenceEstimation::LocalAccuracy => {
227                    self.compute_local_accuracy(&predictions, y_val)?
228                }
229                CompetenceEstimation::DecisionBoundary => {
230                    self.compute_decision_boundary_competence(&predictions, y_val)?
231                }
232                CompetenceEstimation::Entropy => self.compute_entropy_competence(&predictions)?,
233                CompetenceEstimation::Margin => {
234                    self.compute_margin_competence(&predictions, y_val)?
235                }
236            };
237            competence_scores.push(scores);
238        }
239
240        Ok(competence_scores)
241    }
242
243    /// Compute local accuracy competence
244    fn compute_local_accuracy(
245        &self,
246        predictions: &Array1<f64>,
247        y_true: &ArrayView1<'_, Float>,
248    ) -> SklResult<Vec<f64>> {
249        let mut scores = Vec::new();
250
251        for i in 0..predictions.len() {
252            let pred = predictions[i];
253            let true_val = y_true[i];
254
255            // For classification: exact match, for regression: inverse of absolute error
256            let accuracy = if (pred - pred.round()).abs() < 1e-6
257                && (true_val - true_val.round()).abs() < 1e-6
258            {
259                // Classification case
260                if (pred.round() - true_val.round()).abs() < 1e-6 {
261                    1.0
262                } else {
263                    0.0
264                }
265            } else {
266                // Regression case: inverse of absolute error
267                1.0 / (1.0 + (pred - true_val).abs())
268            };
269
270            scores.push(accuracy);
271        }
272
273        Ok(scores)
274    }
275
276    /// Compute decision boundary competence (simplified)
277    fn compute_decision_boundary_competence(
278        &self,
279        predictions: &Array1<f64>,
280        _y_true: &ArrayView1<'_, Float>,
281    ) -> SklResult<Vec<f64>> {
282        // For simplicity, use prediction confidence as proxy for distance to boundary
283        let mut scores = Vec::new();
284
285        for &pred in predictions {
286            // For classification: distance from 0.5 decision boundary
287            // For regression: use inverse of prediction magnitude
288            let confidence = if (0.0..=1.0).contains(&pred) {
289                // Classification probability
290                (pred - 0.5).abs() * 2.0
291            } else {
292                // Regression: normalize by prediction magnitude
293                1.0 / (1.0 + pred.abs())
294            };
295
296            scores.push(confidence);
297        }
298
299        Ok(scores)
300    }
301
302    /// Compute entropy-based competence
303    fn compute_entropy_competence(&self, predictions: &Array1<f64>) -> SklResult<Vec<f64>> {
304        let mut scores = Vec::new();
305
306        for &pred in predictions {
307            // For classification: entropy of prediction probabilities
308            // For regression: use prediction variance as proxy
309            let entropy = if (0.0..=1.0).contains(&pred) {
310                // Binary classification entropy
311                let p = pred.clamp(1e-10, 1.0 - 1e-10);
312                -(p * p.ln() + (1.0 - p) * (1.0 - p).ln())
313            } else {
314                // Regression: use inverse of squared prediction
315                1.0 / (1.0 + pred.powi(2))
316            };
317
318            scores.push(1.0 - entropy); // Higher competence = lower entropy
319        }
320
321        Ok(scores)
322    }
323
324    /// Compute margin-based competence
325    fn compute_margin_competence(
326        &self,
327        predictions: &Array1<f64>,
328        _y_true: &ArrayView1<'_, Float>,
329    ) -> SklResult<Vec<f64>> {
330        // Simplified margin calculation
331        let mut scores = Vec::new();
332
333        for &pred in predictions {
334            // For classification: margin from decision boundary
335            // For regression: inverse of prediction uncertainty
336            let margin = if (0.0..=1.0).contains(&pred) {
337                // Classification: distance from 0.5
338                (pred - 0.5).abs()
339            } else {
340                // Regression: use prediction confidence
341                1.0 / (1.0 + pred.abs())
342            };
343
344            scores.push(margin);
345        }
346
347        Ok(scores)
348    }
349}
350
351impl DynamicEnsembleSelector<DynamicEnsembleSelectorTrained> {
352    /// Predict using dynamic ensemble selection
353    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
354        let mut predictions = Vec::new();
355
356        for i in 0..x.nrows() {
357            let sample = x.row(i);
358            let selected_indices = self.select_estimators_for_sample(sample, i)?;
359
360            // Get predictions from selected estimators
361            let mut sample_predictions = Vec::new();
362            for idx in selected_indices {
363                let pred = self.state.fitted_estimators[idx]
364                    .1
365                    .predict(&x.slice(s![i..i + 1, ..]))?;
366                sample_predictions.push(pred[0]);
367            }
368
369            // Combine predictions (simple averaging)
370            let final_prediction = if sample_predictions.is_empty() {
371                0.0 // Fallback
372            } else {
373                sample_predictions.iter().sum::<f64>() / sample_predictions.len() as f64
374            };
375
376            predictions.push(final_prediction);
377        }
378
379        Ok(Array1::from_vec(predictions))
380    }
381
382    /// Select estimators for a specific sample
383    fn select_estimators_for_sample(
384        &self,
385        sample: ArrayView1<'_, Float>,
386        sample_idx: usize,
387    ) -> SklResult<Vec<usize>> {
388        match &self.selection_strategy {
389            SelectionStrategy::KBest { k } => self.select_k_best_estimators(*k, sample_idx),
390            SelectionStrategy::Threshold { threshold } => {
391                self.select_by_threshold(*threshold, sample_idx)
392            }
393            SelectionStrategy::AboveMedian => self.select_above_median(sample_idx),
394            SelectionStrategy::LocalCompetence { k_neighbors } => {
395                self.select_by_local_competence(&sample, *k_neighbors)
396            }
397        }
398    }
399
400    /// Select k best estimators based on competence scores
401    fn select_k_best_estimators(&self, k: usize, sample_idx: usize) -> SklResult<Vec<usize>> {
402        let mut estimator_scores: Vec<(usize, f64)> = self
403            .state
404            .competence_scores
405            .iter()
406            .enumerate()
407            .map(|(i, scores)| {
408                let score = scores
409                    .get(sample_idx % scores.len())
410                    .copied()
411                    .unwrap_or(0.0);
412                (i, score)
413            })
414            .collect();
415
416        estimator_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
417
418        let selected_k = k.min(estimator_scores.len());
419        Ok(estimator_scores
420            .into_iter()
421            .take(selected_k)
422            .map(|(idx, _)| idx)
423            .collect())
424    }
425
426    /// Select estimators above threshold
427    fn select_by_threshold(&self, threshold: f64, sample_idx: usize) -> SklResult<Vec<usize>> {
428        let selected: Vec<usize> = self
429            .state
430            .competence_scores
431            .iter()
432            .enumerate()
433            .filter_map(|(i, scores)| {
434                let score = scores
435                    .get(sample_idx % scores.len())
436                    .copied()
437                    .unwrap_or(0.0);
438                if score >= threshold {
439                    Some(i)
440                } else {
441                    None
442                }
443            })
444            .collect();
445
446        if selected.is_empty() {
447            // Fallback: select best estimator
448            self.select_k_best_estimators(1, sample_idx)
449        } else {
450            Ok(selected)
451        }
452    }
453
454    /// Select estimators above median performance
455    fn select_above_median(&self, sample_idx: usize) -> SklResult<Vec<usize>> {
456        let scores: Vec<f64> = self
457            .state
458            .competence_scores
459            .iter()
460            .map(|scores| {
461                scores
462                    .get(sample_idx % scores.len())
463                    .copied()
464                    .unwrap_or(0.0)
465            })
466            .collect();
467
468        let mut sorted_scores = scores.clone();
469        sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
470        let median = sorted_scores[sorted_scores.len() / 2];
471
472        let selected: Vec<usize> = scores
473            .iter()
474            .enumerate()
475            .filter_map(|(i, &score)| if score >= median { Some(i) } else { None })
476            .collect();
477
478        if selected.is_empty() {
479            self.select_k_best_estimators(1, sample_idx)
480        } else {
481            Ok(selected)
482        }
483    }
484
485    /// Select estimators based on local competence (k-nearest neighbors)
486    fn select_by_local_competence(
487        &self,
488        sample: &ArrayView1<'_, Float>,
489        k_neighbors: usize,
490    ) -> SklResult<Vec<usize>> {
491        // Find k nearest neighbors in validation set
492        let mut distances: Vec<(usize, f64)> = self
493            .state
494            .validation_data
495            .rows()
496            .into_iter()
497            .enumerate()
498            .map(|(i, val_sample)| {
499                let dist = self.euclidean_distance(*sample, val_sample.mapv(|v| v as Float).view());
500                (i, dist)
501            })
502            .collect();
503
504        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
505        let neighbor_indices: Vec<usize> = distances
506            .into_iter()
507            .take(k_neighbors.min(self.state.validation_data.nrows()))
508            .map(|(idx, _)| idx)
509            .collect();
510
511        // Compute average competence for each estimator in local region
512        let mut local_competences = vec![0.0; self.state.fitted_estimators.len()];
513        for (est_idx, scores) in self.state.competence_scores.iter().enumerate() {
514            let avg_competence = neighbor_indices
515                .iter()
516                .map(|&ni| scores.get(ni).copied().unwrap_or(0.0))
517                .sum::<f64>()
518                / neighbor_indices.len() as f64;
519            local_competences[est_idx] = avg_competence;
520        }
521
522        // Select estimators above average local competence
523        let avg_local_competence =
524            local_competences.iter().sum::<f64>() / local_competences.len() as f64;
525        let selected: Vec<usize> = local_competences
526            .iter()
527            .enumerate()
528            .filter_map(|(i, &comp)| {
529                if comp >= avg_local_competence {
530                    Some(i)
531                } else {
532                    None
533                }
534            })
535            .collect();
536
537        if selected.is_empty() {
538            // Fallback: select best local estimator
539            let best_idx = local_competences
540                .iter()
541                .enumerate()
542                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
543                .map_or(0, |(idx, _)| idx);
544            Ok(vec![best_idx])
545        } else {
546            Ok(selected)
547        }
548    }
549
550    /// Compute Euclidean distance between two samples
551    fn euclidean_distance(&self, a: ArrayView1<'_, Float>, b: ArrayView1<'_, Float>) -> f64 {
552        a.iter()
553            .zip(b.iter())
554            .map(|(&x, &y)| (x - y).powi(2))
555            .sum::<f64>()
556            .sqrt()
557    }
558
559    /// Get fitted estimators
560    #[must_use]
561    pub fn estimators(&self) -> &[(String, Box<dyn PipelinePredictor>)] {
562        &self.state.fitted_estimators
563    }
564}
565
566/// `DynamicEnsembleSelector` builder for fluent construction
567pub struct DynamicEnsembleSelectorBuilder {
568    estimators: Vec<(String, Box<dyn PipelinePredictor>)>,
569    selection_strategy: SelectionStrategy,
570    competence_estimation: CompetenceEstimation,
571    validation_split: f64,
572    n_jobs: Option<i32>,
573}
574
575impl DynamicEnsembleSelectorBuilder {
576    /// Create a new builder
577    #[must_use]
578    pub fn new() -> Self {
579        Self {
580            estimators: Vec::new(),
581            selection_strategy: SelectionStrategy::KBest { k: 3 },
582            competence_estimation: CompetenceEstimation::LocalAccuracy,
583            validation_split: 0.2,
584            n_jobs: None,
585        }
586    }
587
588    /// Add an estimator
589    #[must_use]
590    pub fn estimator(mut self, name: &str, estimator: Box<dyn PipelinePredictor>) -> Self {
591        self.estimators.push((name.to_string(), estimator));
592        self
593    }
594
595    /// Set selection strategy
596    #[must_use]
597    pub fn selection_strategy(mut self, strategy: SelectionStrategy) -> Self {
598        self.selection_strategy = strategy;
599        self
600    }
601
602    /// Set competence estimation method
603    #[must_use]
604    pub fn competence_estimation(mut self, method: CompetenceEstimation) -> Self {
605        self.competence_estimation = method;
606        self
607    }
608
609    /// Set validation split ratio
610    #[must_use]
611    pub fn validation_split(mut self, split: f64) -> Self {
612        self.validation_split = split.clamp(0.1, 0.5);
613        self
614    }
615
616    /// Set number of jobs
617    #[must_use]
618    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
619        self.n_jobs = n_jobs;
620        self
621    }
622
623    /// Build the `DynamicEnsembleSelector`
624    #[must_use]
625    pub fn build(self) -> DynamicEnsembleSelector<Untrained> {
626        DynamicEnsembleSelector {
627            state: Untrained,
628            estimators: self.estimators,
629            selection_strategy: self.selection_strategy,
630            competence_estimation: self.competence_estimation,
631            validation_split: self.validation_split,
632            n_jobs: self.n_jobs,
633        }
634    }
635}
636
637impl Default for DynamicEnsembleSelectorBuilder {
638    fn default() -> Self {
639        Self::new()
640    }
641}