1use crate::PipelinePredictor;
7use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
8use sklears_core::{
9 error::Result as SklResult,
10 prelude::SklearsError,
11 traits::{Estimator, Fit, Untrained},
12 types::Float,
13};
14
/// Rule used at prediction time to decide which fitted estimators vote on a sample.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectionStrategy {
    /// Keep the `k` estimators with the highest competence score.
    KBest { k: usize },
    /// Keep every estimator whose competence score is at least `threshold`; when none
    /// qualifies, the single best estimator is used instead.
    Threshold { threshold: f64 },
    /// Keep every estimator whose competence score is at least the median score.
    AboveMedian,
    /// Average each estimator's competence over the `k_neighbors` validation samples closest
    /// to the query (Euclidean distance) and keep the estimators at or above the mean of
    /// those averages.
    LocalCompetence { k_neighbors: usize },
}
27
/// Method used during fitting to score each estimator on every validation sample.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompetenceEstimation {
    /// Agreement between prediction and target: `1.0`/`0.0` when both are integer-valued,
    /// otherwise `1 / (1 + |prediction - target|)`.
    LocalAccuracy,
    /// Distance of the prediction from `0.5`, scaled to `[0, 1]`, for predictions inside
    /// `[0, 1]`; `1 / (1 + |prediction|)` otherwise.
    DecisionBoundary,
    /// One minus the binary entropy (natural logarithm) of predictions inside `[0, 1]`.
    Entropy,
    /// Absolute distance of the prediction from `0.5` for predictions inside `[0, 1]`;
    /// `1 / (1 + |prediction|)` otherwise.
    Margin,
}
40
41pub struct DynamicEnsembleSelector<S = Untrained> {
59 state: S,
60 estimators: Vec<(String, Box<dyn PipelinePredictor>)>,
61 selection_strategy: SelectionStrategy,
62 competence_estimation: CompetenceEstimation,
63 validation_split: f64,
64 n_jobs: Option<i32>,
65}
66
67pub struct DynamicEnsembleSelectorTrained {
69 fitted_estimators: Vec<(String, Box<dyn PipelinePredictor>)>,
70 validation_data: Array2<f64>,
71 validation_targets: Array1<f64>,
72 competence_scores: Vec<Vec<f64>>, n_features_in: usize,
74 feature_names_in: Option<Vec<String>>,
75}
76
77impl DynamicEnsembleSelector<Untrained> {
78 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 state: Untrained,
83 estimators: Vec::new(),
84 selection_strategy: SelectionStrategy::KBest { k: 3 },
85 competence_estimation: CompetenceEstimation::LocalAccuracy,
86 validation_split: 0.2,
87 n_jobs: None,
88 }
89 }
90
91 #[must_use]
93 pub fn builder() -> DynamicEnsembleSelectorBuilder {
94 DynamicEnsembleSelectorBuilder::new()
95 }
96
97 pub fn add_estimator(&mut self, name: String, estimator: Box<dyn PipelinePredictor>) {
99 self.estimators.push((name, estimator));
100 }
101
102 #[must_use]
104 pub fn selection_strategy(mut self, strategy: SelectionStrategy) -> Self {
105 self.selection_strategy = strategy;
106 self
107 }
108
109 #[must_use]
111 pub fn competence_estimation(mut self, method: CompetenceEstimation) -> Self {
112 self.competence_estimation = method;
113 self
114 }
115
116 #[must_use]
118 pub fn validation_split(mut self, split: f64) -> Self {
119 self.validation_split = split.clamp(0.1, 0.5);
120 self
121 }
122
123 #[must_use]
125 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
126 self.n_jobs = n_jobs;
127 self
128 }
129}
130
131impl Default for DynamicEnsembleSelector<Untrained> {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl Estimator for DynamicEnsembleSelector<Untrained> {
138 type Config = ();
139 type Error = SklearsError;
140 type Float = Float;
141
142 fn config(&self) -> &Self::Config {
143 &()
144 }
145}
146
147impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
148 for DynamicEnsembleSelector<Untrained>
149{
150 type Fitted = DynamicEnsembleSelector<DynamicEnsembleSelectorTrained>;
151
152 fn fit(
153 self,
154 x: &ArrayView2<'_, Float>,
155 y: &Option<&ArrayView1<'_, Float>>,
156 ) -> SklResult<Self::Fitted> {
157 if let Some(y_values) = y.as_ref() {
158 if self.estimators.is_empty() {
159 return Err(SklearsError::InvalidInput(
160 "At least one estimator must be provided".to_string(),
161 ));
162 }
163
164 let n_samples = x.nrows();
165 let validation_size = (n_samples as f64 * self.validation_split).max(1.0) as usize;
166 let train_size = n_samples - validation_size;
167
168 let x_train = x.slice(s![..train_size, ..]);
170 let y_train = y_values.slice(s![..train_size]);
171 let x_val = x.slice(s![train_size.., ..]);
172 let y_val = y_values.slice(s![train_size..]);
173
174 let mut fitted_estimators = Vec::new();
176 let estimators: Vec<(String, Box<dyn PipelinePredictor>)> = self
177 .estimators
178 .iter()
179 .map(|(name, estimator)| (name.clone(), estimator.clone_predictor()))
180 .collect();
181 for (name, mut estimator) in estimators {
182 estimator.fit(&x_train, &y_train)?;
183 fitted_estimators.push((name, estimator));
184 }
185
186 let competence_scores =
188 self.compute_competence_scores(&fitted_estimators, &x_val, &y_val)?;
189
190 Ok(DynamicEnsembleSelector {
191 state: DynamicEnsembleSelectorTrained {
192 fitted_estimators,
193 validation_data: x_val.mapv(|v| v),
194 validation_targets: y_val.mapv(|v| v),
195 competence_scores,
196 n_features_in: x.ncols(),
197 feature_names_in: None,
198 },
199 estimators: Vec::new(),
200 selection_strategy: self.selection_strategy,
201 competence_estimation: self.competence_estimation,
202 validation_split: self.validation_split,
203 n_jobs: self.n_jobs,
204 })
205 } else {
206 Err(SklearsError::InvalidInput(
207 "Target values required for fitting".to_string(),
208 ))
209 }
210 }
211}
212
213impl DynamicEnsembleSelector<Untrained> {
214 fn compute_competence_scores(
216 &self,
217 estimators: &[(String, Box<dyn PipelinePredictor>)],
218 x_val: &ArrayView2<'_, Float>,
219 y_val: &ArrayView1<'_, Float>,
220 ) -> SklResult<Vec<Vec<f64>>> {
221 let mut competence_scores = Vec::new();
222
223 for (_, estimator) in estimators {
224 let predictions = estimator.predict(x_val)?;
225 let scores = match self.competence_estimation {
226 CompetenceEstimation::LocalAccuracy => {
227 self.compute_local_accuracy(&predictions, y_val)?
228 }
229 CompetenceEstimation::DecisionBoundary => {
230 self.compute_decision_boundary_competence(&predictions, y_val)?
231 }
232 CompetenceEstimation::Entropy => self.compute_entropy_competence(&predictions)?,
233 CompetenceEstimation::Margin => {
234 self.compute_margin_competence(&predictions, y_val)?
235 }
236 };
237 competence_scores.push(scores);
238 }
239
240 Ok(competence_scores)
241 }
242
243 fn compute_local_accuracy(
245 &self,
246 predictions: &Array1<f64>,
247 y_true: &ArrayView1<'_, Float>,
248 ) -> SklResult<Vec<f64>> {
249 let mut scores = Vec::new();
250
251 for i in 0..predictions.len() {
252 let pred = predictions[i];
253 let true_val = y_true[i];
254
255 let accuracy = if (pred - pred.round()).abs() < 1e-6
257 && (true_val - true_val.round()).abs() < 1e-6
258 {
259 if (pred.round() - true_val.round()).abs() < 1e-6 {
261 1.0
262 } else {
263 0.0
264 }
265 } else {
266 1.0 / (1.0 + (pred - true_val).abs())
268 };
269
270 scores.push(accuracy);
271 }
272
273 Ok(scores)
274 }
275
276 fn compute_decision_boundary_competence(
278 &self,
279 predictions: &Array1<f64>,
280 _y_true: &ArrayView1<'_, Float>,
281 ) -> SklResult<Vec<f64>> {
282 let mut scores = Vec::new();
284
285 for &pred in predictions {
286 let confidence = if (0.0..=1.0).contains(&pred) {
289 (pred - 0.5).abs() * 2.0
291 } else {
292 1.0 / (1.0 + pred.abs())
294 };
295
296 scores.push(confidence);
297 }
298
299 Ok(scores)
300 }
301
302 fn compute_entropy_competence(&self, predictions: &Array1<f64>) -> SklResult<Vec<f64>> {
304 let mut scores = Vec::new();
305
306 for &pred in predictions {
307 let entropy = if (0.0..=1.0).contains(&pred) {
310 let p = pred.clamp(1e-10, 1.0 - 1e-10);
312 -(p * p.ln() + (1.0 - p) * (1.0 - p).ln())
313 } else {
314 1.0 / (1.0 + pred.powi(2))
316 };
317
318 scores.push(1.0 - entropy); }
320
321 Ok(scores)
322 }
323
324 fn compute_margin_competence(
326 &self,
327 predictions: &Array1<f64>,
328 _y_true: &ArrayView1<'_, Float>,
329 ) -> SklResult<Vec<f64>> {
330 let mut scores = Vec::new();
332
333 for &pred in predictions {
334 let margin = if (0.0..=1.0).contains(&pred) {
337 (pred - 0.5).abs()
339 } else {
340 1.0 / (1.0 + pred.abs())
342 };
343
344 scores.push(margin);
345 }
346
347 Ok(scores)
348 }
349}
350
351impl DynamicEnsembleSelector<DynamicEnsembleSelectorTrained> {
352 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
354 let mut predictions = Vec::new();
355
356 for i in 0..x.nrows() {
357 let sample = x.row(i);
358 let selected_indices = self.select_estimators_for_sample(sample, i)?;
359
360 let mut sample_predictions = Vec::new();
362 for idx in selected_indices {
363 let pred = self.state.fitted_estimators[idx]
364 .1
365 .predict(&x.slice(s![i..i + 1, ..]))?;
366 sample_predictions.push(pred[0]);
367 }
368
369 let final_prediction = if sample_predictions.is_empty() {
371 0.0 } else {
373 sample_predictions.iter().sum::<f64>() / sample_predictions.len() as f64
374 };
375
376 predictions.push(final_prediction);
377 }
378
379 Ok(Array1::from_vec(predictions))
380 }
381
382 fn select_estimators_for_sample(
384 &self,
385 sample: ArrayView1<'_, Float>,
386 sample_idx: usize,
387 ) -> SklResult<Vec<usize>> {
388 match &self.selection_strategy {
389 SelectionStrategy::KBest { k } => self.select_k_best_estimators(*k, sample_idx),
390 SelectionStrategy::Threshold { threshold } => {
391 self.select_by_threshold(*threshold, sample_idx)
392 }
393 SelectionStrategy::AboveMedian => self.select_above_median(sample_idx),
394 SelectionStrategy::LocalCompetence { k_neighbors } => {
395 self.select_by_local_competence(&sample, *k_neighbors)
396 }
397 }
398 }
399
400 fn select_k_best_estimators(&self, k: usize, sample_idx: usize) -> SklResult<Vec<usize>> {
402 let mut estimator_scores: Vec<(usize, f64)> = self
403 .state
404 .competence_scores
405 .iter()
406 .enumerate()
407 .map(|(i, scores)| {
408 let score = scores
409 .get(sample_idx % scores.len())
410 .copied()
411 .unwrap_or(0.0);
412 (i, score)
413 })
414 .collect();
415
416 estimator_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
417
418 let selected_k = k.min(estimator_scores.len());
419 Ok(estimator_scores
420 .into_iter()
421 .take(selected_k)
422 .map(|(idx, _)| idx)
423 .collect())
424 }
425
426 fn select_by_threshold(&self, threshold: f64, sample_idx: usize) -> SklResult<Vec<usize>> {
428 let selected: Vec<usize> = self
429 .state
430 .competence_scores
431 .iter()
432 .enumerate()
433 .filter_map(|(i, scores)| {
434 let score = scores
435 .get(sample_idx % scores.len())
436 .copied()
437 .unwrap_or(0.0);
438 if score >= threshold {
439 Some(i)
440 } else {
441 None
442 }
443 })
444 .collect();
445
446 if selected.is_empty() {
447 self.select_k_best_estimators(1, sample_idx)
449 } else {
450 Ok(selected)
451 }
452 }
453
454 fn select_above_median(&self, sample_idx: usize) -> SklResult<Vec<usize>> {
456 let scores: Vec<f64> = self
457 .state
458 .competence_scores
459 .iter()
460 .map(|scores| {
461 scores
462 .get(sample_idx % scores.len())
463 .copied()
464 .unwrap_or(0.0)
465 })
466 .collect();
467
468 let mut sorted_scores = scores.clone();
469 sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
470 let median = sorted_scores[sorted_scores.len() / 2];
471
472 let selected: Vec<usize> = scores
473 .iter()
474 .enumerate()
475 .filter_map(|(i, &score)| if score >= median { Some(i) } else { None })
476 .collect();
477
478 if selected.is_empty() {
479 self.select_k_best_estimators(1, sample_idx)
480 } else {
481 Ok(selected)
482 }
483 }
484
485 fn select_by_local_competence(
487 &self,
488 sample: &ArrayView1<'_, Float>,
489 k_neighbors: usize,
490 ) -> SklResult<Vec<usize>> {
491 let mut distances: Vec<(usize, f64)> = self
493 .state
494 .validation_data
495 .rows()
496 .into_iter()
497 .enumerate()
498 .map(|(i, val_sample)| {
499 let dist = self.euclidean_distance(*sample, val_sample.mapv(|v| v as Float).view());
500 (i, dist)
501 })
502 .collect();
503
504 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
505 let neighbor_indices: Vec<usize> = distances
506 .into_iter()
507 .take(k_neighbors.min(self.state.validation_data.nrows()))
508 .map(|(idx, _)| idx)
509 .collect();
510
511 let mut local_competences = vec![0.0; self.state.fitted_estimators.len()];
513 for (est_idx, scores) in self.state.competence_scores.iter().enumerate() {
514 let avg_competence = neighbor_indices
515 .iter()
516 .map(|&ni| scores.get(ni).copied().unwrap_or(0.0))
517 .sum::<f64>()
518 / neighbor_indices.len() as f64;
519 local_competences[est_idx] = avg_competence;
520 }
521
522 let avg_local_competence =
524 local_competences.iter().sum::<f64>() / local_competences.len() as f64;
525 let selected: Vec<usize> = local_competences
526 .iter()
527 .enumerate()
528 .filter_map(|(i, &comp)| {
529 if comp >= avg_local_competence {
530 Some(i)
531 } else {
532 None
533 }
534 })
535 .collect();
536
537 if selected.is_empty() {
538 let best_idx = local_competences
540 .iter()
541 .enumerate()
542 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
543 .map_or(0, |(idx, _)| idx);
544 Ok(vec![best_idx])
545 } else {
546 Ok(selected)
547 }
548 }
549
550 fn euclidean_distance(&self, a: ArrayView1<'_, Float>, b: ArrayView1<'_, Float>) -> f64 {
552 a.iter()
553 .zip(b.iter())
554 .map(|(&x, &y)| (x - y).powi(2))
555 .sum::<f64>()
556 .sqrt()
557 }
558
559 #[must_use]
561 pub fn estimators(&self) -> &[(String, Box<dyn PipelinePredictor>)] {
562 &self.state.fitted_estimators
563 }
564}
565
566pub struct DynamicEnsembleSelectorBuilder {
568 estimators: Vec<(String, Box<dyn PipelinePredictor>)>,
569 selection_strategy: SelectionStrategy,
570 competence_estimation: CompetenceEstimation,
571 validation_split: f64,
572 n_jobs: Option<i32>,
573}
574
575impl DynamicEnsembleSelectorBuilder {
576 #[must_use]
578 pub fn new() -> Self {
579 Self {
580 estimators: Vec::new(),
581 selection_strategy: SelectionStrategy::KBest { k: 3 },
582 competence_estimation: CompetenceEstimation::LocalAccuracy,
583 validation_split: 0.2,
584 n_jobs: None,
585 }
586 }
587
588 #[must_use]
590 pub fn estimator(mut self, name: &str, estimator: Box<dyn PipelinePredictor>) -> Self {
591 self.estimators.push((name.to_string(), estimator));
592 self
593 }
594
595 #[must_use]
597 pub fn selection_strategy(mut self, strategy: SelectionStrategy) -> Self {
598 self.selection_strategy = strategy;
599 self
600 }
601
602 #[must_use]
604 pub fn competence_estimation(mut self, method: CompetenceEstimation) -> Self {
605 self.competence_estimation = method;
606 self
607 }
608
609 #[must_use]
611 pub fn validation_split(mut self, split: f64) -> Self {
612 self.validation_split = split.clamp(0.1, 0.5);
613 self
614 }
615
616 #[must_use]
618 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
619 self.n_jobs = n_jobs;
620 self
621 }
622
623 #[must_use]
625 pub fn build(self) -> DynamicEnsembleSelector<Untrained> {
626 DynamicEnsembleSelector {
627 state: Untrained,
628 estimators: self.estimators,
629 selection_strategy: self.selection_strategy,
630 competence_estimation: self.competence_estimation,
631 validation_split: self.validation_split,
632 n_jobs: self.n_jobs,
633 }
634 }
635}
636
637impl Default for DynamicEnsembleSelectorBuilder {
638 fn default() -> Self {
639 Self::new()
640 }
641}