sklears_ensemble/voting/ensemble.rs

//! Ensemble member management and traits for voting classifiers

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{error::Result, types::Float};

/// Trait for estimators that can be used in an ensemble
pub trait EnsembleMember {
    /// Get estimator weight in the ensemble
    fn weight(&self) -> Float;

    /// Set estimator weight
    fn set_weight(&mut self, weight: Float);

    /// Get estimator performance metric
    fn performance(&self) -> Float;

    /// Update performance metric
    fn update_performance(&mut self, performance: Float);

    /// Get prediction confidence
    fn confidence(&self) -> Float;

    /// Make predictions on input data
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>>;

    /// Make probability predictions (if supported)
    fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>>;

    /// Check if estimator supports probability predictions
    fn supports_proba(&self) -> bool;

    /// Get feature importance (if available)
    fn feature_importance(&self) -> Option<Array1<Float>>;

    /// Get model complexity measure
    fn complexity(&self) -> Float;

    /// Check if model is fitted
    fn is_fitted(&self) -> bool;

    /// Get number of classes (for classifiers)
    fn n_classes(&self) -> Option<usize>;

    /// Get number of features expected
    fn n_features(&self) -> Option<usize>;

    /// Calculate prediction uncertainty
    fn uncertainty(&self, x: &Array2<Float>) -> Result<Array1<Float>>;

    /// Get model name/identifier
    fn name(&self) -> String;

    /// Clone the estimator (for ensemble operations)
    fn clone_estimator(&self) -> Box<dyn EnsembleMember + Send + Sync>;
}
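
// Illustrative sketch (not part of the public API): how `weight()` and
// `predict()` from `EnsembleMember` could combine into a weighted majority
// vote over binary labels (0.0 / 1.0). The function name and the 0.5 cutoff
// are assumptions made purely for this example.
#[allow(dead_code)]
fn weighted_binary_vote_example(
    members: &[Box<dyn EnsembleMember + Send + Sync>],
    x: &Array2<Float>,
) -> Result<Array1<Float>> {
    let n_samples = x.nrows();
    let mut score = Array1::<Float>::zeros(n_samples);
    let mut total_weight = 0.0;

    // Accumulate each member's weighted vote for class 1.
    for member in members {
        let predictions = member.predict(x)?;
        let w = member.weight();
        total_weight += w;
        for i in 0..n_samples {
            score[i] += w * predictions[i];
        }
    }

    // Label a sample 1.0 when its weighted vote share exceeds one half.
    Ok(score.mapv(|s| {
        if total_weight > 0.0 && s / total_weight > 0.5 {
            1.0
        } else {
            0.0
        }
    }))
}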

/// Mock estimator for testing ensemble functionality
#[derive(Debug, Clone)]
pub struct MockEstimator {
    weight: Float,
    performance: Float,
    confidence: Float,
    bias: Float,
    supports_proba: bool,
    is_fitted: bool,
    n_classes: Option<usize>,
    n_features: Option<usize>,
    name: String,
}

impl MockEstimator {
    pub fn new(bias: Float) -> Self {
        Self {
            weight: 1.0,
            performance: 0.8,
            confidence: 0.9,
            bias,
            supports_proba: true,
            is_fitted: true,
            n_classes: Some(2),
            n_features: Some(2),
            name: format!("MockEstimator_{}", bias),
        }
    }

    pub fn with_weight(mut self, weight: Float) -> Self {
        self.weight = weight;
        self
    }

    pub fn with_performance(mut self, performance: Float) -> Self {
        self.performance = performance;
        self
    }

    pub fn with_confidence(mut self, confidence: Float) -> Self {
        self.confidence = confidence;
        self
    }

    pub fn with_proba_support(mut self, supports: bool) -> Self {
        self.supports_proba = supports;
        self
    }

    pub fn with_fitted_status(mut self, fitted: bool) -> Self {
        self.is_fitted = fitted;
        self
    }

    pub fn with_classes(mut self, n_classes: usize) -> Self {
        self.n_classes = Some(n_classes);
        self
    }

    pub fn with_features(mut self, n_features: usize) -> Self {
        self.n_features = Some(n_features);
        self
    }

    pub fn with_name(mut self, name: String) -> Self {
        self.name = name;
        self
    }
}

impl EnsembleMember for MockEstimator {
    fn weight(&self) -> Float {
        self.weight
    }

    fn set_weight(&mut self, weight: Float) {
        self.weight = weight;
    }

    fn performance(&self) -> Float {
        self.performance
    }

    fn update_performance(&mut self, performance: Float) {
        self.performance = performance;
    }

    fn confidence(&self) -> Float {
        self.confidence
    }

    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        if !self.is_fitted {
            return Err(sklears_core::error::SklearsError::NotFitted {
                operation: "predict".to_string(),
            });
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);

        // Simple mock prediction: bias towards a specific class based on features
        for i in 0..n_samples {
            let feature_sum: Float = x.row(i).sum();
            let prediction = if feature_sum + self.bias > 0.0 {
                1.0
            } else {
                0.0
            };
            predictions[i] = prediction;
        }

        Ok(predictions)
    }

    fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        if !self.supports_proba {
            return Err(sklears_core::error::SklearsError::InvalidOperation(
                "Estimator does not support probability predictions".to_string(),
            ));
        }

        if !self.is_fitted {
            return Err(sklears_core::error::SklearsError::NotFitted {
                operation: "predict_proba".to_string(),
            });
        }

        let n_samples = x.nrows();
        let n_classes = self.n_classes.unwrap_or(2);
        let mut probabilities = Array2::zeros((n_samples, n_classes));

        // Simple mock probability: sigmoid-like transformation of features
        for i in 0..n_samples {
            let feature_sum: Float = x.row(i).sum();
            let logit = feature_sum + self.bias;

            if n_classes == 2 {
                // Binary classification
                let prob_class_1 = 1.0 / (1.0 + (-logit).exp());
                probabilities[[i, 0]] = 1.0 - prob_class_1;
                probabilities[[i, 1]] = prob_class_1;
            } else {
                // Multi-class: uniform distribution for simplicity
                let prob_per_class = 1.0 / n_classes as Float;
                for j in 0..n_classes {
                    probabilities[[i, j]] = prob_per_class;
                }
                // Add some bias to the first class (clamped at zero so the
                // probability entry cannot go negative for a negative bias)
                probabilities[[i, 0]] = (probabilities[[i, 0]] + self.bias * 0.1).max(0.0);

                // Normalize to ensure sum = 1
                let row_sum: Float = probabilities.row(i).sum();
                if row_sum > 0.0 {
                    for j in 0..n_classes {
                        probabilities[[i, j]] /= row_sum;
                    }
                }
            }
        }

        Ok(probabilities)
    }

    fn supports_proba(&self) -> bool {
        self.supports_proba
    }

    fn feature_importance(&self) -> Option<Array1<Float>> {
        if let Some(n_features) = self.n_features {
            // Mock feature importance: uniform importance with some bias
            let mut importance = Array1::ones(n_features) / n_features as Float;
            if n_features > 0 {
                importance[0] += self.bias.abs() * 0.1; // First feature gets extra importance
            }

            // Normalize
            let total: Float = importance.sum();
            if total > 0.0 {
                importance.mapv_inplace(|x| x / total);
            }

            Some(importance)
        } else {
            None
        }
    }

    fn complexity(&self) -> Float {
        // Mock complexity based on bias magnitude
        self.bias.abs() + 1.0
    }

    fn is_fitted(&self) -> bool {
        self.is_fitted
    }

    fn n_classes(&self) -> Option<usize> {
        self.n_classes
    }

    fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    fn uncertainty(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        if !self.is_fitted {
            return Err(sklears_core::error::SklearsError::NotFitted {
                operation: "uncertainty".to_string(),
            });
        }

        let n_samples = x.nrows();
        let mut uncertainty = Array1::zeros(n_samples);

        // Mock uncertainty: binary entropy, highest for samples close to the
        // decision boundary (logit near 0)
        for i in 0..n_samples {
            let feature_sum: Float = x.row(i).sum();
            let logit = feature_sum + self.bias;

            // Clamp the probability away from 0 and 1 so the entropy terms
            // never evaluate ln(0) and produce NaN for extreme logits
            let prob = (1.0 / (1.0 + (-logit).exp())).clamp(1e-12, 1.0 - 1e-12);
            let entropy = -prob * prob.ln() - (1.0 - prob) * (1.0 - prob).ln();
            uncertainty[i] = entropy;
        }

        Ok(uncertainty)
    }

    fn name(&self) -> String {
        self.name.clone()
    }

    fn clone_estimator(&self) -> Box<dyn EnsembleMember + Send + Sync> {
        Box::new(self.clone())
    }
}

/// Ensemble member wrapper for external estimators
#[derive(Debug)]
pub struct ExternalEstimatorWrapper {
    weight: Float,
    performance: Float,
    confidence: Float,
    name: String,
}

impl ExternalEstimatorWrapper {
    pub fn new(name: String) -> Self {
        Self {
            weight: 1.0,
            performance: 0.0,
            confidence: 0.5,
            name,
        }
    }
}

impl EnsembleMember for ExternalEstimatorWrapper {
    fn weight(&self) -> Float {
        self.weight
    }

    fn set_weight(&mut self, weight: Float) {
        self.weight = weight;
    }

    fn performance(&self) -> Float {
        self.performance
    }

    fn update_performance(&mut self, performance: Float) {
        self.performance = performance;
    }

    fn confidence(&self) -> Float {
        self.confidence
    }

    fn predict(&self, _x: &Array2<Float>) -> Result<Array1<Float>> {
        Err(sklears_core::error::SklearsError::NotImplemented(
            "External estimator prediction not implemented".to_string(),
        ))
    }

    fn predict_proba(&self, _x: &Array2<Float>) -> Result<Array2<Float>> {
        Err(sklears_core::error::SklearsError::NotImplemented(
            "External estimator probability prediction not implemented".to_string(),
        ))
    }

    fn supports_proba(&self) -> bool {
        false
    }

    fn feature_importance(&self) -> Option<Array1<Float>> {
        None
    }

    fn complexity(&self) -> Float {
        1.0
    }

    fn is_fitted(&self) -> bool {
        true
    }

    fn n_classes(&self) -> Option<usize> {
        None
    }

    fn n_features(&self) -> Option<usize> {
        None
    }

    fn uncertainty(&self, _x: &Array2<Float>) -> Result<Array1<Float>> {
        Err(sklears_core::error::SklearsError::NotImplemented(
            "External estimator uncertainty estimation not implemented".to_string(),
        ))
    }

    fn name(&self) -> String {
        self.name.clone()
    }

    fn clone_estimator(&self) -> Box<dyn EnsembleMember + Send + Sync> {
        Box::new(Self {
            weight: self.weight,
            performance: self.performance,
            confidence: self.confidence,
            name: self.name.clone(),
        })
    }
}

/// Utility functions for ensemble management
pub mod ensemble_utils {
    use super::*;

    /// Calculate ensemble diversity using prediction disagreement
    pub fn calculate_ensemble_diversity(
        estimators: &[Box<dyn EnsembleMember + Send + Sync>],
        x: &Array2<Float>,
    ) -> Result<Float> {
        if estimators.len() < 2 {
            return Ok(0.0);
        }

        let n_samples = x.nrows();
        let n_estimators = estimators.len();

        // Collect all predictions
        let mut all_predictions = Vec::new();
        for estimator in estimators {
            let predictions = estimator.predict(x)?;
            all_predictions.push(predictions);
        }

        // Calculate pairwise disagreements
        let mut total_disagreement = 0.0;
        let mut n_pairs = 0;

        for i in 0..n_estimators {
            for j in (i + 1)..n_estimators {
                let mut disagreements = 0;
                for sample_idx in 0..n_samples {
                    if (all_predictions[i][sample_idx] - all_predictions[j][sample_idx]).abs()
                        > 1e-6
                    {
                        disagreements += 1;
                    }
                }
                total_disagreement += disagreements as Float / n_samples as Float;
                n_pairs += 1;
            }
        }

        if n_pairs > 0 {
            Ok(total_disagreement / n_pairs as Float)
        } else {
            Ok(0.0)
        }
    }

    /// Update ensemble weights based on recent performance
    pub fn update_ensemble_weights(
        estimators: &mut [Box<dyn EnsembleMember + Send + Sync>],
        recent_performances: &[Float],
        learning_rate: Float,
    ) {
        if estimators.len() != recent_performances.len() {
            return;
        }

        // Calculate performance-based weights
        let total_performance: Float = recent_performances.iter().sum();

        if total_performance > 1e-8 {
            for (estimator, &performance) in estimators.iter_mut().zip(recent_performances.iter()) {
                let current_weight = estimator.weight();
                let target_weight = performance / total_performance;
                let new_weight = current_weight + learning_rate * (target_weight - current_weight);
                estimator.set_weight(new_weight.max(0.01)); // Minimum weight to avoid zero
            }
        }
    }

    /// Prune underperforming estimators from the ensemble
    pub fn prune_ensemble(
        estimators: &mut Vec<Box<dyn EnsembleMember + Send + Sync>>,
        performance_threshold: Float,
        min_ensemble_size: usize,
    ) {
        if estimators.len() <= min_ensemble_size {
            return;
        }

        // Only prune when enough estimators survive the threshold; otherwise
        // leave the ensemble untouched so the minimum size is never violated.
        let n_surviving = estimators
            .iter()
            .filter(|estimator| estimator.performance() >= performance_threshold)
            .count();

        if n_surviving >= min_ensemble_size {
            estimators.retain(|estimator| estimator.performance() >= performance_threshold);
        }
    }

    /// Get ensemble statistics
    pub fn get_ensemble_stats(
        estimators: &[Box<dyn EnsembleMember + Send + Sync>],
    ) -> EnsembleStats {
        if estimators.is_empty() {
            return EnsembleStats::default();
        }

        let weights: Vec<Float> = estimators.iter().map(|e| e.weight()).collect();
        let performances: Vec<Float> = estimators.iter().map(|e| e.performance()).collect();
        let confidences: Vec<Float> = estimators.iter().map(|e| e.confidence()).collect();

        let mean_weight = weights.iter().sum::<Float>() / weights.len() as Float;
        let mean_performance = performances.iter().sum::<Float>() / performances.len() as Float;
        let mean_confidence = confidences.iter().sum::<Float>() / confidences.len() as Float;

        let weight_variance = weights
            .iter()
            .map(|&w| (w - mean_weight).powi(2))
            .sum::<Float>()
            / weights.len() as Float;

        EnsembleStats {
            n_estimators: estimators.len(),
            mean_weight,
            mean_performance,
            mean_confidence,
            weight_variance,
            total_complexity: estimators.iter().map(|e| e.complexity()).sum(),
        }
    }
}

/// Statistics about an ensemble
#[derive(Debug, Clone)]
pub struct EnsembleStats {
    pub n_estimators: usize,
    pub mean_weight: Float,
    pub mean_performance: Float,
    pub mean_confidence: Float,
    pub weight_variance: Float,
    pub total_complexity: Float,
}

impl Default for EnsembleStats {
    fn default() -> Self {
        Self {
            n_estimators: 0,
            mean_weight: 0.0,
            mean_performance: 0.0,
            mean_confidence: 0.0,
            weight_variance: 0.0,
            total_complexity: 0.0,
        }
    }
}
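
// Minimal usage sketch of the mock estimator together with the ensemble
// utilities. The assertions rely only on the mock's simple sign-based rule
// (predict 1.0 when feature_sum + bias > 0); they are assumptions about this
// mock, not about any real estimator.
#[cfg(test)]
mod tests {
    use super::ensemble_utils::*;
    use super::*;

    #[test]
    fn mock_estimators_disagree_and_stats_aggregate() {
        // Two mocks with opposite biases disagree on samples whose feature
        // sums lie strictly between the two biases, so diversity is positive.
        let members: Vec<Box<dyn EnsembleMember + Send + Sync>> = vec![
            Box::new(MockEstimator::new(1.0).with_weight(2.0)),
            Box::new(MockEstimator::new(-1.0).with_performance(0.6)),
        ];

        let x = Array2::from_shape_vec((2, 2), vec![0.25, 0.25, -0.25, -0.25]).unwrap();
        let diversity = calculate_ensemble_diversity(&members, &x).unwrap();
        assert!(diversity > 0.0);

        let stats = get_ensemble_stats(&members);
        assert_eq!(stats.n_estimators, 2);
        assert!(stats.mean_weight > 1.0);
    }

    #[test]
    fn weights_move_toward_recent_performance() {
        let mut members: Vec<Box<dyn EnsembleMember + Send + Sync>> = vec![
            Box::new(MockEstimator::new(0.5)),
            Box::new(MockEstimator::new(-0.5)),
        ];

        // The member with the stronger recent score should end up heavier.
        update_ensemble_weights(&mut members, &[0.9, 0.1], 0.5);
        assert!(members[0].weight() > members[1].weight());
    }
}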