use super::config::MetaLearningStrategy;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

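/// Meta-learner that combines base-estimator predictions (e.g. as the
/// combiner stage of a stacking ensemble). After fitting, predictions are a
/// linear combination of the meta-features plus an intercept.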
#[derive(Debug, Clone)]
pub struct MetaLearner {
    /// Strategy used to combine the base-estimator predictions.
    pub strategy: MetaLearningStrategy,
    /// Learned combination weights, one per meta-feature (set by `fit`).
    pub weights: Option<Array1<Float>>,
    /// Learned intercept term (set by `fit`).
    pub intercept: Option<Float>,
    /// Number of meta-features seen during `fit`.
    pub n_features: Option<usize>,
}

impl MetaLearner {
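    /// Creates an unfitted meta-learner that will use the given strategy.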
    pub fn new(strategy: MetaLearningStrategy) -> Self {
        Self {
            strategy,
            weights: None,
            intercept: None,
            n_features: None,
        }
    }

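    /// Fits the meta-learner on meta-features (typically one column per base
    /// estimator) and the corresponding targets. Returns a shape-mismatch
    /// error if the row count of `meta_features` differs from `targets.len()`.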
    pub fn fit(&mut self, meta_features: &Array2<Float>, targets: &Array1<Float>) -> Result<()> {
        if meta_features.nrows() != targets.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("{} samples", meta_features.nrows()),
                actual: format!("{} samples", targets.len()),
            });
        }

        let n_features = meta_features.ncols();
        self.n_features = Some(n_features);

        match self.strategy {
            MetaLearningStrategy::LinearRegression => {
                let (weights, intercept) = self.fit_linear_regression(meta_features, targets)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::Ridge(alpha) => {
                let (weights, intercept) =
                    self.fit_ridge_regression(meta_features, targets, alpha)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::Lasso(alpha) => {
                let (weights, intercept) =
                    self.fit_lasso_regression(meta_features, targets, alpha)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::ElasticNet(alpha, l1_ratio) => {
                let (weights, intercept) =
                    self.fit_elastic_net(meta_features, targets, alpha, l1_ratio)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::LogisticRegression => {
                let (weights, intercept) = self.fit_logistic_regression(meta_features, targets)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
            MetaLearningStrategy::BayesianAveraging => {
                // Uniform weights over the base estimators; no fitting needed.
                self.weights = Some(Array1::from_elem(n_features, 1.0 / n_features as Float));
                self.intercept = Some(0.0);
            }
            _ => {
                // Remaining strategies fall back to ordinary least squares.
                let (weights, intercept) = self.fit_linear_regression(meta_features, targets)?;
                self.weights = Some(weights);
                self.intercept = Some(intercept);
            }
        }

        Ok(())
    }

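    /// Predicts targets as a linear combination of the meta-features plus the
    /// fitted intercept. Errors if the learner is unfitted or the feature
    /// count does not match the one seen in `fit`.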
    pub fn predict(&self, meta_features: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;
        let intercept = self.intercept.ok_or_else(|| SklearsError::NotFitted {
            operation: "predict".to_string(),
        })?;
        let n_features = self.n_features.ok_or_else(|| SklearsError::NotFitted {
            operation: "predict".to_string(),
        })?;

        if meta_features.ncols() != n_features {
            return Err(SklearsError::FeatureMismatch {
                expected: n_features,
                actual: meta_features.ncols(),
            });
        }

        let n_samples = meta_features.nrows();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = meta_features.row(i);
            predictions[i] = sample.dot(weights) + intercept;
        }

        Ok(predictions)
    }

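    /// Ordinary least squares on a bias-augmented design matrix: solves the
    /// normal equations (X^T X) params = X^T y, where the last column of X is
    /// all ones so the final parameter is the intercept.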
    fn fit_linear_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Float)> {
        let (n_samples, n_features) = x.dim();

        // Augment X with a trailing column of ones for the intercept.
        let mut x_aug = Array2::ones((n_samples, n_features + 1));
        x_aug.slice_mut(s![.., ..n_features]).assign(x);

        // Normal equations: (X^T X) params = X^T y.
        let xtx = x_aug.t().dot(&x_aug);
        let xty = x_aug.t().dot(y);

        let params = self.solve_linear_system(&xtx, &xty)?;
        let intercept = params[n_features];
        let weights = params.slice(s![..n_features]).to_owned();

        Ok((weights, intercept))
    }

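    /// Ridge regression: least squares with an L2 penalty, implemented by
    /// adding `alpha` to the diagonal of X^T X for the weight entries only,
    /// leaving the intercept unpenalized.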
    fn fit_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
    ) -> Result<(Array1<Float>, Float)> {
        let (n_samples, n_features) = x.dim();

        // Augment X with a trailing column of ones for the intercept.
        let mut x_aug = Array2::ones((n_samples, n_features + 1));
        x_aug.slice_mut(s![.., ..n_features]).assign(x);

        let mut xtx = x_aug.t().dot(&x_aug);

        // Penalize only the weight diagonal; the last row/column belongs to
        // the intercept and stays unregularized.
        for i in 0..n_features {
            xtx[[i, i]] += alpha;
        }

        let xty = x_aug.t().dot(y);
        let params = self.solve_linear_system(&xtx, &xty)?;

        let intercept = params[n_features];
        let weights = params.slice(s![..n_features]).to_owned();

        Ok((weights, intercept))
    }

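    /// Lasso regression via cyclic coordinate descent with soft-thresholding,
    /// run for a fixed 100 sweeps. The intercept is refit after each sweep
    /// and is not penalized.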
    fn fit_lasso_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
    ) -> Result<(Array1<Float>, Float)> {
        let (n_samples, n_features) = x.dim();
        let mut weights = Array1::zeros(n_features);
        let mut intercept = y.mean().unwrap_or(0.0);

        for _iter in 0..100 {
            for j in 0..n_features {
                // Correlation of column j with the partial residual, i.e. the
                // prediction computed with weights[j] left out.
                let mut residual = 0.0;
                let mut x_squared_sum = 0.0;
                for i in 0..n_samples {
                    let mut prediction = intercept;
                    for k in 0..n_features {
                        if k != j {
                            prediction += weights[k] * x[[i, k]];
                        }
                    }
                    residual += x[[i, j]] * (y[i] - prediction);
                    x_squared_sum += x[[i, j]] * x[[i, j]];
                }

                // Soft-thresholding update. Dividing by the column's sum of
                // squares (rather than n_samples) keeps the update correct
                // for unstandardized features, consistent with
                // `fit_elastic_net`.
                let threshold = alpha * n_samples as Float;
                if x_squared_sum < 1e-12 {
                    weights[j] = 0.0;
                } else if residual > threshold {
                    weights[j] = (residual - threshold) / x_squared_sum;
                } else if residual < -threshold {
                    weights[j] = (residual + threshold) / x_squared_sum;
                } else {
                    weights[j] = 0.0;
                }
            }

            // Refit the (unpenalized) intercept after each sweep.
            let mut prediction_sum = 0.0;
            for i in 0..n_samples {
                prediction_sum += weights.dot(&x.row(i));
            }
            intercept = (y.sum() - prediction_sum) / n_samples as Float;
        }

        Ok((weights, intercept))
    }

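    /// Elastic net via cyclic coordinate descent: the penalty splits into an
    /// L1 part (`alpha * l1_ratio`, applied by soft-thresholding) and an L2
    /// part (`alpha * (1 - l1_ratio)`, folded into the update denominator).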
    fn fit_elastic_net(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
    ) -> Result<(Array1<Float>, Float)> {
        let l1_alpha = alpha * l1_ratio;
        let l2_alpha = alpha * (1.0 - l1_ratio);

        let (n_samples, n_features) = x.dim();
        let mut weights = Array1::zeros(n_features);
        let mut intercept = y.mean().unwrap_or(0.0);

        for _iter in 0..100 {
            for j in 0..n_features {
                let mut residual = 0.0;
                let mut x_squared_sum = 0.0;

                for i in 0..n_samples {
                    let mut prediction = intercept;
                    for k in 0..n_features {
                        if k != j {
                            prediction += weights[k] * x[[i, k]];
                        }
                    }
                    residual += x[[i, j]] * (y[i] - prediction);
                    x_squared_sum += x[[i, j]] * x[[i, j]];
                }

                // Soft-thresholding with the L2 penalty in the denominator.
                let threshold = l1_alpha * n_samples as Float;
                let denominator = x_squared_sum + l2_alpha * n_samples as Float;

                if residual > threshold {
                    weights[j] = (residual - threshold) / denominator;
                } else if residual < -threshold {
                    weights[j] = (residual + threshold) / denominator;
                } else {
                    weights[j] = 0.0;
                }
            }

            // Refit the (unpenalized) intercept after each sweep.
            let mut prediction_sum = 0.0;
            for i in 0..n_samples {
                prediction_sum += weights.dot(&x.row(i));
            }
            intercept = (y.sum() - prediction_sum) / n_samples as Float;
        }

        Ok((weights, intercept))
    }

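    /// Logistic regression fit by batch gradient ascent on the log-likelihood
    /// with a fixed learning rate and iteration count. Targets are expected
    /// to be 0/1 labels.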
    fn fit_logistic_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(Array1<Float>, Float)> {
        let (n_samples, n_features) = x.dim();
        let mut weights = Array1::zeros(n_features);
        let mut intercept = 0.0;
        let learning_rate = 0.01;

        for _iter in 0..1000 {
            // Accumulate the batch gradient of the log-likelihood.
            let mut weight_gradients = Array1::<Float>::zeros(n_features);
            let mut intercept_gradient = 0.0;

            for i in 0..n_samples {
                let z = weights.dot(&x.row(i)) + intercept;
                let prediction = self.sigmoid(z);
                let error = y[i] - prediction;

                for j in 0..n_features {
                    weight_gradients[j] += error * x[[i, j]];
                }
                intercept_gradient += error;
            }

            // Gradient ascent step, averaged over the samples.
            for j in 0..n_features {
                weights[j] += learning_rate * weight_gradients[j] / n_samples as Float;
            }
            intercept += learning_rate * intercept_gradient / n_samples as Float;
        }

        Ok((weights, intercept))
    }

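    /// Logistic sigmoid: sigma(z) = 1 / (1 + e^(-z)).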
    fn sigmoid(&self, z: Float) -> Float {
        1.0 / (1.0 + (-z).exp())
    }

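    /// Solves the square linear system `a * x = b` by Gaussian elimination
    /// with partial pivoting followed by back-substitution. Returns a
    /// numerical error if a pivot is (near) zero, i.e. the matrix is
    /// singular.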
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions don't match".to_string(),
            ));
        }

        // Build the augmented matrix [A | b].
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting.
        for i in 0..n {
            let mut max_row = i;
            for k in (i + 1)..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..(n + 1) {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[i, i]].abs() < 1e-12 {
                return Err(SklearsError::NumericalError(
                    "Singular matrix in linear system".to_string(),
                ));
            }

            for k in (i + 1)..n {
                let factor = aug[[k, i]] / aug[[i, i]];
                for j in i..(n + 1) {
                    aug[[k, j]] -= factor * aug[[i, j]];
                }
            }
        }

        // Back-substitution.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = aug[[i, n]];
            for j in (i + 1)..n {
                x[i] -= aug[[i, j]] * x[j];
            }
            x[i] /= aug[[i, i]];
        }

        Ok(x)
    }
}

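/// Measures ensemble diversity as one minus the mean absolute Pearson
/// correlation over all pairs of estimator prediction columns. Returns a
/// value in [0, 1]; higher means more diverse predictions.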
pub fn calculate_diversity(predictions: &Array2<Float>) -> Result<Float> {
    let (_n_samples, n_estimators) = predictions.dim();

    if n_estimators < 2 {
        return Ok(0.0);
    }

    // Average the absolute pairwise correlations between estimator columns.
    let mut total_correlation = 0.0;
    let mut count = 0;

    for i in 0..n_estimators {
        for j in (i + 1)..n_estimators {
            let pred_i = predictions.column(i);
            let pred_j = predictions.column(j);

            let correlation = calculate_correlation(&pred_i, &pred_j)?;
            total_correlation += correlation.abs();
            count += 1;
        }
    }

    if count == 0 {
        Ok(0.0)
    } else {
        // High average correlation means low diversity.
        Ok(1.0 - total_correlation / count as Float)
    }
}

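/// Pearson correlation coefficient between two equal-length vectors. Returns
/// 0.0 for vectors shorter than two elements or with (near) zero variance.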
pub fn calculate_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Result<Float> {
    if x.len() != y.len() {
        return Err(SklearsError::InvalidInput(
            "Vectors must have the same length".to_string(),
        ));
    }

    let n = x.len() as Float;
    if n < 2.0 {
        return Ok(0.0);
    }

    let mean_x = x.sum() / n;
    let mean_y = y.sum() / n;

    let mut numerator = 0.0;
    let mut sum_sq_x = 0.0;
    let mut sum_sq_y = 0.0;

    for i in 0..x.len() {
        let dx = x[i] - mean_x;
        let dy = y[i] - mean_y;

        numerator += dx * dy;
        sum_sq_x += dx * dx;
        sum_sq_y += dy * dy;
    }

    let denominator = (sum_sq_x * sum_sq_y).sqrt();

    // Treat (near) constant vectors as uncorrelated rather than divide by
    // zero.
    if denominator < 1e-12 {
        Ok(0.0)
    } else {
        Ok(numerator / denominator)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_meta_learner_creation() {
        let meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);
        assert!(matches!(
            meta_learner.strategy,
            MetaLearningStrategy::LinearRegression
        ));
        assert!(meta_learner.weights.is_none());
        assert!(meta_learner.intercept.is_none());
    }

    #[test]
    fn test_linear_regression_fit_predict() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);

        let meta_features = array![
            [1.0, 0.5],
            [2.0, 1.0],
            [0.5, 2.0],
            [1.5, 0.8],
            [0.3, 1.2],
            [2.1, 0.4]
        ];
        let targets = array![1.2, 2.1, 1.8, 1.6, 1.1, 1.9];

        meta_learner.fit(&meta_features, &targets).unwrap();

        assert!(meta_learner.weights.is_some());
        assert!(meta_learner.intercept.is_some());

        let predictions = meta_learner.predict(&meta_features).unwrap();
        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_ridge_regression() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::Ridge(0.1));

        let meta_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let targets = array![3.0, 5.0, 7.0, 9.0];

        meta_learner.fit(&meta_features, &targets).unwrap();
        let predictions = meta_learner.predict(&meta_features).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_bayesian_averaging() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::BayesianAveraging);

        let meta_features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let targets = array![6.0, 15.0];

        meta_learner.fit(&meta_features, &targets).unwrap();

        let weights = meta_learner.weights.as_ref().unwrap();
        assert_eq!(weights.len(), 3);
        assert!((weights.sum() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_diversity_calculation() {
        let predictions = array![
            [1.0, 2.0, 1.5],
            [2.0, 3.0, 2.2],
            [3.0, 4.0, 3.8],
            [4.0, 5.0, 4.1]
        ];

        let diversity = calculate_diversity(&predictions).unwrap();
        assert!(diversity >= 0.0 && diversity <= 1.0);
    }

    #[test]
    fn test_correlation_calculation() {
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let correlation = calculate_correlation(&x.view(), &y.view()).unwrap();
        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_shape_mismatch_error() {
        let mut meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);

        let meta_features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![3.0];

        let result = meta_learner.fit(&meta_features, &targets);
        assert!(result.is_err());
    }

    #[test]
    fn test_not_fitted_error() {
        let meta_learner = MetaLearner::new(MetaLearningStrategy::LinearRegression);
        let meta_features = array![[1.0, 2.0]];

        let result = meta_learner.predict(&meta_features);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not fitted"));
    }
}