1use super::config::StackingConfig;
7use crate::simd_stacking;
8use scirs2_core::ndarray::{s, Array1, Array2};
9use sklears_core::{
10 error::{Result, SklearsError},
11 prelude::Predict,
12 traits::{Fit, Trained, Untrained},
13 types::Float,
14};
15use std::marker::PhantomData;
16
17#[derive(Debug)]
23pub struct SimpleStackingClassifier<State = Untrained> {
24 pub(crate) config: StackingConfig,
25 pub(crate) state: PhantomData<State>,
26 pub(crate) base_weights_: Option<Array2<Float>>, pub(crate) base_intercepts_: Option<Array1<Float>>, pub(crate) meta_weights_: Option<Array1<Float>>, pub(crate) meta_intercept_: Option<Float>, pub(crate) n_base_estimators_: Option<usize>,
32 pub(crate) classes_: Option<Array1<i32>>,
33 pub(crate) n_features_in_: Option<usize>,
34}
35
36impl SimpleStackingClassifier<Untrained> {
37 pub fn new(n_base_estimators: usize) -> Self {
39 Self {
40 config: StackingConfig::default(),
41 state: PhantomData,
42 base_weights_: None,
43 base_intercepts_: None,
44 meta_weights_: None,
45 meta_intercept_: None,
46 n_base_estimators_: Some(n_base_estimators),
47 classes_: None,
48 n_features_in_: None,
49 }
50 }
51
52 pub fn cv(mut self, cv: usize) -> Self {
54 self.config.cv = cv;
55 self
56 }
57
58 pub fn use_probabilities(mut self, use_probabilities: bool) -> Self {
60 self.config.use_probabilities = use_probabilities;
61 self
62 }
63
64 pub fn random_state(mut self, random_state: u64) -> Self {
66 self.config.random_state = Some(random_state);
67 self
68 }
69
70 pub fn passthrough(mut self, passthrough: bool) -> Self {
72 self.config.passthrough = passthrough;
73 self
74 }
75}
76
77impl Fit<Array2<Float>, Array1<i32>> for SimpleStackingClassifier<Untrained> {
78 type Fitted = SimpleStackingClassifier<Trained>;
79
80 fn fit(self, x: &Array2<Float>, y: &Array1<i32>) -> Result<Self::Fitted> {
81 if x.nrows() != y.len() {
82 return Err(SklearsError::ShapeMismatch {
83 expected: format!("{} samples", x.nrows()),
84 actual: format!("{} samples", y.len()),
85 });
86 }
87
88 let (n_samples, n_features) = x.dim();
89 let n_base_estimators = self.n_base_estimators_.unwrap();
90
91 if n_samples < 10 {
92 return Err(SklearsError::InvalidInput(
93 "Stacking requires at least 10 samples".to_string(),
94 ));
95 }
96
97 let mut classes: Vec<i32> = y.to_vec();
99 classes.sort_unstable();
100 classes.dedup();
101 let classes_array = Array1::from_vec(classes.clone());
102 let n_classes = classes.len();
103
104 if n_classes < 2 {
105 return Err(SklearsError::InvalidInput(
106 "Need at least 2 classes for classification".to_string(),
107 ));
108 }
109
110 let y_float: Array1<Float> = y.mapv(|v| v as Float);
112
113 let (base_weights, base_intercepts) = self.train_base_estimators(x, &y_float)?;
115
116 let meta_features =
118 self.generate_meta_features(x, &y_float, &base_weights, &base_intercepts)?;
119
120 let (meta_weights, meta_intercept) = self.train_meta_learner(&meta_features, &y_float)?;
122
123 Ok(SimpleStackingClassifier {
124 config: self.config,
125 state: PhantomData,
126 base_weights_: Some(base_weights),
127 base_intercepts_: Some(base_intercepts),
128 meta_weights_: Some(meta_weights),
129 meta_intercept_: Some(meta_intercept),
130 n_base_estimators_: self.n_base_estimators_,
131 classes_: Some(classes_array),
132 n_features_in_: Some(n_features),
133 })
134 }
135}
136
137impl SimpleStackingClassifier<Untrained> {
138 fn train_base_estimators(
140 &self,
141 x: &Array2<Float>,
142 y: &Array1<Float>,
143 ) -> Result<(Array2<Float>, Array1<Float>)> {
144 let (n_samples, n_features) = x.dim();
145 let n_base_estimators = self.n_base_estimators_.unwrap();
146
147 let mut base_weights = Array2::<Float>::zeros((n_base_estimators, n_features));
148 let mut base_intercepts = Array1::<Float>::zeros(n_base_estimators);
149
150 for i in 0..n_base_estimators {
152 let seed = self.config.random_state.unwrap_or(42) + i as u64;
154 let mut rng = scirs2_core::random::Random::seed(seed);
155
156 for j in 0..n_features {
158 base_weights[[i, j]] = (scirs2_core::random::Rng::gen::<f64>(&mut rng) - 0.5) * 2.0;
159 }
160
161 base_intercepts[i] = y.mean().unwrap_or(0.0);
163 }
164
165 Ok((base_weights, base_intercepts))
166 }
167
168 fn generate_meta_features(
170 &self,
171 x: &Array2<Float>,
172 y: &Array1<Float>,
173 base_weights: &Array2<Float>,
174 base_intercepts: &Array1<Float>,
175 ) -> Result<Array2<Float>> {
176 let (n_samples, _) = x.dim();
177 let n_base_estimators = base_weights.nrows();
178
179 let holdout_size = n_samples / self.config.cv;
181 let train_size = n_samples - holdout_size;
182
183 if train_size < 5 {
184 return Err(SklearsError::InvalidInput(
185 "Insufficient samples for cross-validation".to_string(),
186 ));
187 }
188
189 let mut meta_features = Array2::<Float>::zeros((n_samples, n_base_estimators));
191 for i in 0..n_base_estimators {
192 let weights = base_weights.row(i);
193 let intercept = base_intercepts[i];
194 for j in 0..n_samples {
195 let x_sample = x.row(j);
196 let prediction = self.predict_linear(&weights, intercept, &x_sample);
197 meta_features[[j, i]] = prediction;
198 }
199 }
200
201 Ok(meta_features)
202 }
203
204 fn train_meta_learner(
206 &self,
207 meta_features: &Array2<Float>,
208 y: &Array1<Float>,
209 ) -> Result<(Array1<Float>, Float)> {
210 let (n_samples, n_meta_features) = meta_features.dim();
211
212 let mut x_with_intercept = Array2::<Float>::ones((n_samples, n_meta_features + 1));
214 x_with_intercept
215 .slice_mut(s![.., ..n_meta_features])
216 .assign(meta_features);
217
218 let mut xtx = Array2::<Float>::zeros((n_meta_features + 1, n_meta_features + 1));
220 for i in 0..(n_meta_features + 1) {
221 for j in 0..(n_meta_features + 1) {
222 for k in 0..n_samples {
223 xtx[[i, j]] += x_with_intercept[[k, i]] * x_with_intercept[[k, j]];
224 }
225 }
226 xtx[[i, i]] += 0.001;
228 }
229
230 let mut xty = Array1::<Float>::zeros(n_meta_features + 1);
231 for i in 0..(n_meta_features + 1) {
232 for j in 0..n_samples {
233 xty[i] += x_with_intercept[[j, i]] * y[j];
234 }
235 }
236
237 let params = self.solve_linear_system(&xtx, &xty)?;
239
240 let intercept = params[n_meta_features];
241 let weights = params.slice(s![..n_meta_features]).to_owned();
242
243 Ok((weights, intercept))
244 }
245
246 fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
248 let n = a.nrows();
249 if n != a.ncols() || n != b.len() {
250 return Err(SklearsError::InvalidInput(
251 "Matrix dimensions don't match".to_string(),
252 ));
253 }
254
255 let mut aug = Array2::<Float>::zeros((n, n + 1));
257 for i in 0..n {
258 for j in 0..n {
259 aug[[i, j]] = a[[i, j]];
260 }
261 aug[[i, n]] = b[i];
262 }
263
264 for i in 0..n {
266 let mut max_row = i;
268 for k in (i + 1)..n {
269 if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
270 max_row = k;
271 }
272 }
273
274 if max_row != i {
276 for j in 0..(n + 1) {
277 let temp = aug[[i, j]];
278 aug[[i, j]] = aug[[max_row, j]];
279 aug[[max_row, j]] = temp;
280 }
281 }
282
283 if aug[[i, i]].abs() < 1e-10 {
285 return Err(SklearsError::NumericalError(
286 "Singular matrix in linear system".to_string(),
287 ));
288 }
289
290 for k in (i + 1)..n {
292 let factor = aug[[k, i]] / aug[[i, i]];
293 for j in i..(n + 1) {
294 aug[[k, j]] -= factor * aug[[i, j]];
295 }
296 }
297 }
298
299 let mut x = Array1::<Float>::zeros(n);
301 for i in (0..n).rev() {
302 x[i] = aug[[i, n]];
303 for j in (i + 1)..n {
304 x[i] -= aug[[i, j]] * x[j];
305 }
306 x[i] /= aug[[i, i]];
307 }
308
309 Ok(x)
310 }
311
312 fn predict_linear(
314 &self,
315 weights: &scirs2_core::ndarray::ArrayView1<Float>,
316 intercept: Float,
317 x: &scirs2_core::ndarray::ArrayView1<Float>,
318 ) -> Float {
319 simd_stacking::simd_linear_prediction(x, weights, intercept)
321 }
322}
323
impl SimpleStackingClassifier<Trained> {
    /// Computes `weights · x + intercept` via the SIMD-accelerated helper.
    /// Mirrors the identically-named helper on the `Untrained` impl so
    /// training and inference share one prediction code path.
    fn predict_linear(
        &self,
        weights: &scirs2_core::ndarray::ArrayView1<Float>,
        intercept: Float,
        x: &scirs2_core::ndarray::ArrayView1<Float>,
    ) -> Float {
        simd_stacking::simd_linear_prediction(x, weights, intercept)
    }
}
336
337impl Predict<Array2<Float>, Array1<i32>> for SimpleStackingClassifier<Trained> {
338 fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
339 if x.ncols() != self.n_features_in_.unwrap() {
340 return Err(SklearsError::FeatureMismatch {
341 expected: self.n_features_in_.unwrap(),
342 actual: x.ncols(),
343 });
344 }
345
346 let n_samples = x.nrows();
347 let n_base_estimators = self.n_base_estimators_.unwrap();
348
349 let base_weights = self.base_weights_.as_ref().unwrap();
350 let base_intercepts = self.base_intercepts_.as_ref().unwrap();
351 let meta_weights = self.meta_weights_.as_ref().unwrap();
352 let meta_intercept = self.meta_intercept_.unwrap();
353 let classes = self.classes_.as_ref().unwrap();
354
355 let meta_features = simd_stacking::simd_generate_meta_features(
357 &x.view(),
358 &base_weights.view(),
359 &base_intercepts.view(),
360 )
361 .unwrap_or_else(|_| {
362 let mut meta_features = Array2::<Float>::zeros((n_samples, n_base_estimators));
364 for i in 0..n_base_estimators {
365 let weights = base_weights.row(i);
366 let intercept = base_intercepts[i];
367 for j in 0..n_samples {
368 let x_sample = x.row(j);
369 let prediction = self.predict_linear(&weights, intercept, &x_sample);
370 meta_features[[j, i]] = prediction;
371 }
372 }
373 meta_features
374 });
375
376 let raw_predictions = simd_stacking::simd_aggregate_predictions(
378 &meta_features.view(),
379 &meta_weights.view(),
380 meta_intercept,
381 )
382 .unwrap_or_else(|_| {
383 let mut predictions = Array1::<Float>::zeros(n_samples);
385 for i in 0..n_samples {
386 let meta_sample = meta_features.row(i);
387 predictions[i] = meta_weights.dot(&meta_sample) + meta_intercept;
388 }
389 predictions
390 });
391
392 let mut predictions = Array1::<i32>::zeros(n_samples);
393
394 for i in 0..n_samples {
395 let raw_prediction = raw_predictions[i];
396
397 let class_pred = if raw_prediction >= 0.5 {
399 classes[classes.len() - 1] } else {
401 classes[0] };
403
404 predictions[i] = class_pred;
405 }
406
407 Ok(predictions)
408 }
409}
410
impl SimpleStackingClassifier<Trained> {
    /// Sorted, deduplicated class labels seen during `fit`.
    pub fn classes(&self) -> &Array1<i32> {
        self.classes_.as_ref().unwrap()
    }

    /// Number of input features seen during `fit`.
    pub fn n_features_in(&self) -> usize {
        self.n_features_in_.unwrap()
    }

    /// Number of base estimators in the ensemble.
    pub fn n_base_estimators(&self) -> usize {
        self.n_base_estimators_.unwrap()
    }

    /// Base-estimator weight matrix, one row per estimator.
    pub fn base_weights(&self) -> &Array2<Float> {
        self.base_weights_.as_ref().unwrap()
    }

    /// Base-estimator intercepts.
    pub fn base_intercepts(&self) -> &Array1<Float> {
        self.base_intercepts_.as_ref().unwrap()
    }

    /// Meta-learner weights over base-estimator outputs.
    pub fn meta_weights(&self) -> &Array1<Float> {
        self.meta_weights_.as_ref().unwrap()
    }

    /// Meta-learner intercept.
    pub fn meta_intercept(&self) -> Float {
        self.meta_intercept_.unwrap()
    }
}
447
448pub use SimpleStackingClassifier as StackingClassifier;
450
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builder methods should record their settings in the config.
    #[test]
    fn test_stacking_creation() {
        let stacking = StackingClassifier::new(3)
            .cv(5)
            .random_state(42)
            .passthrough(true);

        assert_eq!(stacking.config.cv, 5);
        assert_eq!(stacking.config.random_state, Some(42));
        assert!(stacking.config.passthrough);
        assert_eq!(stacking.n_base_estimators_.unwrap(), 3);
    }

    /// End-to-end: fit on a small binary dataset, then predict on it.
    #[test]
    fn test_stacking_fit_predict() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0],
            [13.0, 14.0],
            [15.0, 16.0],
            [17.0, 18.0],
            [19.0, 20.0],
            [21.0, 22.0],
            [23.0, 24.0]
        ];
        let y = array![0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1];

        let stacking = StackingClassifier::new(2);
        let fitted_model = stacking.fit(&x, &y).unwrap();

        assert_eq!(fitted_model.n_features_in(), 2);
        assert_eq!(fitted_model.classes().len(), 2);

        let predictions = fitted_model.predict(&x).unwrap();
        assert_eq!(predictions.len(), 12);
    }

    /// Mismatched sample counts must be rejected before training starts.
    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0];

        let stacking = StackingClassifier::new(1);
        let result = stacking.fit(&x, &y);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Shape mismatch"));
    }

    /// Predicting with a different feature count than seen at fit time
    /// must produce a feature-mismatch error.
    #[test]
    fn test_feature_mismatch() {
        let x_train = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0],
            [13.0, 14.0],
            [15.0, 16.0],
            [17.0, 18.0],
            [19.0, 20.0],
            [21.0, 22.0],
            [23.0, 24.0]
        ];
        let y_train = array![0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1];
        let x_test = array![[1.0, 2.0, 3.0]];

        let stacking = StackingClassifier::new(1);
        let fitted_model = stacking.fit(&x_train, &y_train).unwrap();
        let result = fitted_model.predict(&x_test);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Feature"));
    }
}