sklears_ensemble/stacking/meta_learning.rs

//! Meta-learning strategies and utilities for stacking ensembles
//!
//! This module provides implementations of various meta-learning algorithms
//! used to combine predictions from base estimators in stacking ensembles.
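//!
//! # Example
//!
//! A minimal, illustrative sketch of the fit/predict cycle. The import paths
//! below are assumptions based on this module's location; adjust them to the
//! crate's actual re-exports.
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use sklears_ensemble::stacking::config::MetaLearningStrategy;
//! use sklears_ensemble::stacking::meta_learning::MetaLearner;
//!
//! // Out-of-fold predictions from two base estimators (columns) for four samples.
//! let meta_features = array![[1.0, 0.5], [2.0, 1.0], [0.5, 2.0], [1.5, 0.8]];
//! let targets = array![1.2, 2.1, 1.8, 1.6];
//!
//! let mut meta_learner = MetaLearner::new(MetaLearningStrategy::Ridge(0.1));
//! meta_learner.fit(&meta_features, &targets).unwrap();
//! let predictions = meta_learner.predict(&meta_features).unwrap();
//! assert_eq!(predictions.len(), 4);
//! ```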

use super::config::MetaLearningStrategy;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

/// Meta-learner implementation that combines base estimator predictions
#[derive(Debug, Clone)]
pub struct MetaLearner {
    /// The strategy used for meta-learning
    pub strategy: MetaLearningStrategy,
    /// Learned weights (if applicable)
    pub weights: Option<Array1<Float>>,
    /// Learned intercept (if applicable)
    pub intercept: Option<Float>,
    /// Number of features expected during prediction
    pub n_features: Option<usize>,
}

impl MetaLearner {
    /// Create a new meta-learner with the specified strategy
    pub fn new(strategy: MetaLearningStrategy) -> Self {
        Self {
            strategy,
            weights: None,
            intercept: None,
            n_features: None,
        }
    }

    /// Train the meta-learner on meta-features
    pub fn fit(&mut self, meta_features: &Array2<Float>, targets: &Array1<Float>) -> Result<()> {
        if meta_features.nrows() != targets.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("{} samples", meta_features.nrows()),
                actual: format!("{} samples", targets.len()),
            });
        }

        let n_features = meta_features.ncols();
        self.n_features = Some(n_features);

        match self.strategy {
            MetaLearningStrategy::LinearRegression => {
                let (weights, intercept) = self.fit_linear_regression(meta_features, targets)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::Ridge(alpha) => {
                let (weights, intercept) =
                    self.fit_ridge_regression(meta_features, targets, alpha)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::Lasso(alpha) => {
                let (weights, intercept) =
                    self.fit_lasso_regression(meta_features, targets, alpha)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::ElasticNet(alpha, l1_ratio) => {
                let (weights, intercept) =
                    self.fit_elastic_net(meta_features, targets, alpha, l1_ratio)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::LogisticRegression => {
                let (weights, intercept) = self.fit_logistic_regression(meta_features, targets)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::BayesianAveraging => {
                // Bayesian averaging uses uniform weights
                self.weights = Some(Array1::from_elem(n_features, 1.0 / n_features as Float));
                self.intercept = Some(0.0);
            }
            _ => {
                // For other strategies, use linear regression as fallback
                let (weights, intercept) = self.fit_linear_regression(meta_features, targets)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
        }

        Ok(())
    }

    /// Make predictions using the trained meta-learner
    pub fn predict(&self, meta_features: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;
        let intercept = self.intercept.ok_or_else(|| SklearsError::NotFitted {
            operation: "predict".to_string(),
        })?;
        let n_features = self.n_features.ok_or_else(|| SklearsError::NotFitted {
            operation: "predict".to_string(),
        })?;
        if meta_features.ncols() != n_features {
            return Err(SklearsError::FeatureMismatch {
                expected: n_features,
                actual: meta_features.ncols(),
            });
        }

        // Vectorized X · w + b over all samples
        Ok(meta_features.dot(weights) + intercept)
    }

    /// Fit linear regression using normal equations
    fn fit_linear_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Float)> {
        let (n_samples, n_features) = x.dim();

        // Create augmented feature matrix with intercept column
        let mut x_aug = Array2::ones((n_samples, n_features + 1));
        x_aug.slice_mut(s![.., ..n_features]).assign(x);

        // Solve normal equations: (X^T X)^(-1) X^T y
        let xtx = x_aug.t().dot(&x_aug);
        let xty = x_aug.t().dot(y);

        let params = self.solve_linear_system(&xtx, &xty)?;
        let intercept = params[n_features];
        let weights = params.slice(s![..n_features]).to_owned();

        Ok((weights, intercept))
    }

    /// Fit Ridge regression with L2 regularization
    fn fit_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
    ) -> Result<(Array1<Float>, Float)> {
        let (n_samples, n_features) = x.dim();

        // Create augmented feature matrix with intercept column
        let mut x_aug = Array2::ones((n_samples, n_features + 1));
        x_aug.slice_mut(s![.., ..n_features]).assign(x);

        // Solve regularized normal equations: (X^T X + αI)^(-1) X^T y
        let mut xtx = x_aug.t().dot(&x_aug);

        // Add regularization to diagonal (except intercept term)
        for i in 0..n_features {
            xtx[[i, i]] += alpha;
        }

        let xty = x_aug.t().dot(y);
        let params = self.solve_linear_system(&xtx, &xty)?;

        let intercept = params[n_features];
        let weights = params.slice(s![..n_features]).to_owned();

        Ok((weights, intercept))
    }
    /// Fit Lasso regression with L1 regularization (simplified implementation)
    fn fit_lasso_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
    ) -> Result<(Array1<Float>, Float)> {
        // For simplicity, use coordinate descent approximation with soft thresholding
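        // Coordinate update for the lasso objective
        //   (1/(2n)) * ||y - X·w - b||^2 + alpha * ||w||_1
        // with rho_j = sum_i x_ij * (y_i - prediction_without_j):
        //   w_j = sign(rho_j) * max(|rho_j| - n * alpha, 0) / sum_i x_ij^2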
        let (n_samples, n_features) = x.dim();
        let mut weights = Array1::zeros(n_features);
        let mut intercept = y.mean().unwrap_or(0.0);

        // Simple coordinate descent iterations
        for _iter in 0..100 {
            for j in 0..n_features {
                let mut residual = 0.0;
                let mut x_squared_sum = 0.0;
                for i in 0..n_samples {
                    let mut prediction = intercept;
                    for k in 0..n_features {
                        if k != j {
                            prediction += weights[k] * x[[i, k]];
                        }
                    }
                    residual += x[[i, j]] * (y[i] - prediction);
                    x_squared_sum += x[[i, j]] * x[[i, j]];
                }

                // Skip all-zero feature columns to avoid division by zero
                if x_squared_sum < 1e-12 {
                    weights[j] = 0.0;
                    continue;
                }

                // Soft thresholding
                let threshold = alpha * n_samples as Float;
                if residual > threshold {
                    weights[j] = (residual - threshold) / x_squared_sum;
                } else if residual < -threshold {
                    weights[j] = (residual + threshold) / x_squared_sum;
                } else {
                    weights[j] = 0.0;
                }
            }

            // Update intercept
            let mut prediction_sum = 0.0;
            for i in 0..n_samples {
                prediction_sum += weights.dot(&x.row(i));
            }
            intercept = (y.sum() - prediction_sum) / n_samples as Float;
        }

        Ok((weights, intercept))
    }

    /// Fit Elastic Net regression with combined L1/L2 regularization
    fn fit_elastic_net(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
    ) -> Result<(Array1<Float>, Float)> {
        let l1_alpha = alpha * l1_ratio;
        let l2_alpha = alpha * (1.0 - l1_ratio);

        // Simplified elastic net using coordinate descent
        let (n_samples, n_features) = x.dim();
        let mut weights = Array1::zeros(n_features);
        let mut intercept = y.mean().unwrap_or(0.0);

        for _iter in 0..100 {
            for j in 0..n_features {
                let mut residual = 0.0;
                let mut x_squared_sum = 0.0;

                for i in 0..n_samples {
                    let mut prediction = intercept;
                    for k in 0..n_features {
                        if k != j {
                            prediction += weights[k] * x[[i, k]];
                        }
                    }
                    residual += x[[i, j]] * (y[i] - prediction);
                    x_squared_sum += x[[i, j]] * x[[i, j]];
                }

                // Soft thresholding with L2 modification
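                // Closed-form coordinate update for the elastic net objective
                //   (1/(2n)) * ||y - X·w - b||^2 + l1_alpha * ||w||_1 + (l2_alpha / 2) * ||w||_2^2
                // The L1 term soft-thresholds the numerator, while the L2 term
                // only inflates the denominator, shrinking w_j without zeroing it.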
                let threshold = l1_alpha * n_samples as Float;
                let denominator = x_squared_sum + l2_alpha * n_samples as Float;

                if residual > threshold {
                    weights[j] = (residual - threshold) / denominator;
                } else if residual < -threshold {
                    weights[j] = (residual + threshold) / denominator;
                } else {
                    weights[j] = 0.0;
                }
            }

            // Update intercept
            let mut prediction_sum = 0.0;
            for i in 0..n_samples {
                prediction_sum += weights.dot(&x.row(i));
            }
            intercept = (y.sum() - prediction_sum) / n_samples as Float;
        }

        Ok((weights, intercept))
    }

    /// Fit logistic regression (simplified implementation)
    fn fit_logistic_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Float)> {
        let (n_samples, n_features) = x.dim();
        let mut weights = Array1::zeros(n_features);
        let mut intercept = 0.0;
        let learning_rate = 0.01;

        // Simple gradient descent
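        // With p = sigmoid(w·x + b), the gradient of the log-likelihood is
        // sum_i (y_i - p_i) * x_i, so the additive updates below perform
        // gradient ascent on the log-likelihood (equivalently, gradient
        // descent on the mean cross-entropy loss).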
        for _iter in 0..1000 {
            let mut weight_gradients = Array1::<Float>::zeros(n_features);
            let mut intercept_gradient = 0.0;

            for i in 0..n_samples {
                let z = weights.dot(&x.row(i)) + intercept;
                let prediction = self.sigmoid(z);
                let error = y[i] - prediction;

                for j in 0..n_features {
                    weight_gradients[j] += error * x[[i, j]];
                }
                intercept_gradient += error;
            }

            // Update parameters
            for j in 0..n_features {
                weights[j] += learning_rate * weight_gradients[j] / n_samples as Float;
            }
            intercept += learning_rate * intercept_gradient / n_samples as Float;
        }

        Ok((weights, intercept))
    }

    /// Sigmoid activation function
    fn sigmoid(&self, z: Float) -> Float {
        1.0 / (1.0 + (-z).exp())
    }

    /// Solve linear system using Gaussian elimination
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions don't match".to_string(),
            ));
        }

        // Create augmented matrix
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting
        for i in 0..n {
            // Find pivot
            let mut max_row = i;
            for k in (i + 1)..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            // Swap rows
            if max_row != i {
                for j in 0..(n + 1) {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            // Check for singular matrix
            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Singular matrix in linear system".to_string(),
                ));
            }

            // Eliminate column
            for k in (i + 1)..n {
                let factor = aug[[k, i]] / aug[[i, i]];
                for j in i..(n + 1) {
                    aug[[k, j]] -= factor * aug[[i, j]];
                }
            }
        }

        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = aug[[i, n]];
            for j in (i + 1)..n {
                x[i] -= aug[[i, j]] * x[j];
            }
            x[i] /= aug[[i, i]];
        }

        Ok(x)
    }
}

/// Calculate ensemble diversity using pairwise correlation
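///
/// Diversity is defined as `1 - mean(|r_ij|)` over all pairs of estimator
/// columns, where `r_ij` is the Pearson correlation between the predictions of
/// estimators `i` and `j`: 0 means every pair is perfectly (anti-)correlated,
/// while values near 1 indicate decorrelated, complementary estimators.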
pub fn calculate_diversity(predictions: &Array2<Float>) -> Result<Float> {
    let n_estimators = predictions.ncols();

    if n_estimators < 2 {
        return Ok(0.0);
    }

    let mut total_correlation = 0.0;
    let mut count = 0;

    for i in 0..n_estimators {
        for j in (i + 1)..n_estimators {
            let pred_i = predictions.column(i);
            let pred_j = predictions.column(j);

            let correlation = calculate_correlation(&pred_i, &pred_j)?;
            total_correlation += correlation.abs();
            count += 1;
        }
    }

    if count == 0 {
        Ok(0.0)
    } else {
        // Diversity is 1 - average absolute correlation
        Ok(1.0 - total_correlation / count as Float)
    }
}

/// Calculate Pearson correlation coefficient between two vectors
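///
/// Computed as `r = Σ(xᵢ - x̄)(yᵢ - ȳ) / sqrt(Σ(xᵢ - x̄)² · Σ(yᵢ - ȳ)²)`,
/// returning `0.0` for degenerate inputs (fewer than two points or zero variance).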
pub fn calculate_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Result<Float> {
    if x.len() != y.len() {
        return Err(SklearsError::InvalidInput(
            "Vectors must have the same length".to_string(),
        ));
    }

    let n = x.len() as Float;
    if n < 2.0 {
        return Ok(0.0);
    }

    let mean_x = x.sum() / n;
    let mean_y = y.sum() / n;

    let mut numerator = 0.0;
    let mut sum_sq_x = 0.0;
    let mut sum_sq_y = 0.0;

    for i in 0..x.len() {
        let dx = x[i] - mean_x;
        let dy = y[i] - mean_y;

        numerator += dx * dy;
        sum_sq_x += dx * dx;
        sum_sq_y += dy * dy;
    }

    let denominator = (sum_sq_x * sum_sq_y).sqrt();

    if denominator < 1e-12 {
        Ok(0.0)
    } else {
        Ok(numerator / denominator)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_meta_learner_creation() {
        let meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);
        assert!(matches!(
            meta_learner.strategy,
            MetaLearningStrategy::LinearRegression
        ));
        assert!(meta_learner.weights.is_none());
        assert!(meta_learner.intercept.is_none());
    }

    #[test]
    fn test_linear_regression_fit_predict() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);

        // Use feature columns that are not collinear so the normal equations stay well-conditioned
        let meta_features = array![
            [1.0, 0.5],
            [2.0, 1.0],
            [0.5, 2.0],
            [1.5, 0.8],
            [0.3, 1.2],
            [2.1, 0.4]
        ];
        let targets = array![1.2, 2.1, 1.8, 1.6, 1.1, 1.9];

        meta_learner.fit(&meta_features, &targets).unwrap();

        assert!(meta_learner.weights.is_some());
        assert!(meta_learner.intercept.is_some());

        let predictions = meta_learner.predict(&meta_features).unwrap();
        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_ridge_regression() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::Ridge(0.1));

        let meta_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let targets = array![3.0, 5.0, 7.0, 9.0];

        meta_learner.fit(&meta_features, &targets).unwrap();
        let predictions = meta_learner.predict(&meta_features).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_bayesian_averaging() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::BayesianAveraging);

        let meta_features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let targets = array![6.0, 15.0];

        meta_learner.fit(&meta_features, &targets).unwrap();

        let weights = meta_learner.weights.as_ref().unwrap();
        assert_eq!(weights.len(), 3);
        assert!((weights.sum() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_diversity_calculation() {
        let predictions = array![
            [1.0, 2.0, 1.5],
            [2.0, 3.0, 2.2],
            [3.0, 4.0, 3.8],
            [4.0, 5.0, 4.1]
        ];

        let diversity = calculate_diversity(&predictions).unwrap();
        assert!((0.0..=1.0).contains(&diversity));
    }

    #[test]
    fn test_correlation_calculation() {
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![2.0, 4.0, 6.0, 8.0]; // Perfect positive correlation

        let correlation = calculate_correlation(&x.view(), &y.view()).unwrap();
        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);

        let meta_features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![3.0]; // Wrong length

        let result = meta_learner.fit(&meta_features, &targets);
        assert!(result.is_err());
    }
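
    // Illustrative sketch: exercises the Lasso and ElasticNet code paths,
    // relying only on the variant signatures already matched in `fit` above;
    // it asserts output shape rather than exact coefficient values.
    #[test]
    fn test_lasso_and_elastic_net_shapes() {
        let meta_features = array![[1.0, 0.5], [2.0, 1.0], [0.5, 2.0], [1.5, 0.8]];
        let targets = array![1.2, 2.1, 1.8, 1.6];

        let mut lasso = MetaLearner::new(MetaLearningStrategy::Lasso(0.1));
        lasso.fit(&meta_features, &targets).unwrap();
        assert_eq!(lasso.predict(&meta_features).unwrap().len(), 4);

        let mut elastic = MetaLearner::new(MetaLearningStrategy::ElasticNet(0.1, 0.5));
        elastic.fit(&meta_features, &targets).unwrap();
        assert_eq!(elastic.predict(&meta_features).unwrap().len(), 4);
    }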

    #[test]
    fn test_not_fitted_error() {
        let meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);
        let meta_features = array![[1.0, 2.0]];

        let result = meta_learner.predict(&meta_features);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not fitted"));
    }
}