use crate::error::{Result, SklearsError};
use rayon::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

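/// An ensemble whose base learners are trained concurrently on a rayon
/// thread pool, with progress tracked in a shared `TrainingState`.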
#[derive(Debug)]
pub struct ParallelEnsemble {
    config: EnsembleConfig,
    base_learners: Vec<Arc<dyn BaseEstimator>>,
    training_state: Arc<RwLock<TrainingState>>,
}

impl ParallelEnsemble {
    /// Build an ensemble from a configuration, instantiating one base
    /// learner per configured estimator.
    pub fn new(config: EnsembleConfig) -> Self {
        let base_learners = Self::create_base_learners(&config);

        Self {
            config,
            base_learners,
            training_state: Arc::new(RwLock::new(TrainingState::new())),
        }
    }

    fn create_base_learners(config: &EnsembleConfig) -> Vec<Arc<dyn BaseEstimator>> {
        let mut learners = Vec::new();

        for i in 0..config.n_estimators {
            let learner: Arc<dyn BaseEstimator> = match &config.ensemble_type {
                EnsembleType::RandomForest => {
                    Arc::new(RandomForestEstimator::new(i, &config.base_config))
                }
                EnsembleType::GradientBoosting => {
                    Arc::new(GradientBoostingEstimator::new(i, &config.base_config))
                }
                EnsembleType::AdaBoost => Arc::new(AdaBoostEstimator::new(i, &config.base_config)),
                EnsembleType::Voting => Arc::new(VotingEstimator::new(i, &config.base_config)),
                EnsembleType::Stacking => Arc::new(StackingEstimator::new(i, &config.base_config)),
            };
            learners.push(learner);
        }

        learners
    }

    /// Number of base learners in the ensemble.
    pub fn n_estimators(&self) -> usize {
        self.base_learners.len()
    }

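    /// Fit every base learner concurrently on a dedicated rayon thread pool
    /// and collect the trained models.
    ///
    /// A minimal usage sketch (`x` and `y` are assumed to be an existing
    /// `Array2<f64>` feature matrix and `Array1<f64>` target vector):
    ///
    /// ```ignore
    /// let config = EnsembleConfig::random_forest().with_n_estimators(10);
    /// let ensemble = ParallelEnsemble::new(config);
    /// let trained = ensemble.parallel_fit(&x.view(), &y.view())?;
    /// let predictions = trained.parallel_predict(&x.view())?;
    /// ```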
    pub fn parallel_fit(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> Result<TrainedParallelEnsemble> {
        let start_time = Instant::now();

        // Record training metadata before spawning workers.
        {
            let mut state = self.training_state.write().unwrap();
            state.start_training(x.nrows(), self.n_estimators());
        }

        // Build a dedicated thread pool sized by the parallel configuration.
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.config.parallel_config.num_threads)
            .build()
            .map_err(|e| {
                SklearsError::InvalidInput(format!("Failed to create thread pool: {e}"))
            })?;

        // Train all base learners in parallel, reporting per-learner progress.
        let trained_learners = pool.install(|| {
            self.base_learners
                .par_iter()
                .enumerate()
                .map(|(i, learner)| {
                    let result = self.train_single_learner(learner.as_ref(), x, y, i);

                    {
                        let mut state = self.training_state.write().unwrap();
                        state.update_progress(i, result.is_ok());
                    }

                    result
                })
                .collect::<Result<Vec<_>>>()
        })?;

        {
            let mut state = self.training_state.write().unwrap();
            state.complete_training(start_time.elapsed());
        }

        Ok(TrainedParallelEnsemble {
            config: self.config.clone(),
            trained_learners,
            training_metrics: self.training_state.read().unwrap().clone(),
        })
    }

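    /// Train one base learner on its (possibly resampled) share of the data,
    /// recording its wall-clock training time and training accuracy.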
    fn train_single_learner(
        &self,
        learner: &dyn BaseEstimator,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        learner_id: usize,
    ) -> Result<TrainedBaseEstimator> {
        // Resample the training data according to the configured strategy.
        let (train_x, train_y) = self.prepare_training_data(x, y, learner_id)?;

        let start_time = Instant::now();
        let trained = learner.fit(&train_x.view(), &train_y.view())?;
        let training_time = start_time.elapsed();

        // Accuracy on the learner's own training sample, later used as a
        // weight for weighted voting.
        let training_accuracy =
            self.compute_training_accuracy(trained.as_ref(), &train_x, &train_y)?;

        Ok(TrainedBaseEstimator {
            learner_id,
            model: trained,
            training_time,
            training_accuracy,
        })
    }

    fn prepare_training_data(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        learner_id: usize,
    ) -> Result<(Array2<f64>, Array1<f64>)> {
        match self.config.sampling_strategy {
            SamplingStrategy::Bootstrap => self.bootstrap_sample(x, y, learner_id),
            SamplingStrategy::Bagging => self.bag_sample(x, y, learner_id),
            SamplingStrategy::None => Ok((x.to_owned(), y.to_owned())),
        }
    }

    /// Draw `n_samples` rows with replacement (classic bootstrap resampling),
    /// seeded per learner so every estimator sees a different sample.
    fn bootstrap_sample(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        seed: usize,
    ) -> Result<(Array2<f64>, Array1<f64>)> {
        let mut rng = Random::seed(self.config.random_seed + seed as u64);
        let n_samples = x.nrows();

        let mut sampled_x = Array2::zeros((n_samples, x.ncols()));
        let mut sampled_y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample_idx = rng.gen_range(0..n_samples);
            sampled_x.row_mut(i).assign(&x.row(sample_idx));
            sampled_y[i] = y[sample_idx];
        }

        Ok((sampled_x, sampled_y))
    }

    /// Draw a subsample without replacement, sized by `subsample_ratio`.
    fn bag_sample(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
        seed: usize,
    ) -> Result<(Array2<f64>, Array1<f64>)> {
        let mut rng = Random::seed(self.config.random_seed + seed as u64);
        let n_samples = x.nrows();
        let sample_size = (n_samples as f64 * self.config.subsample_ratio).round() as usize;

        let mut indices: Vec<usize> = (0..n_samples).collect();
        rng.shuffle(&mut indices);
        indices.truncate(sample_size);

        let mut sampled_x = Array2::zeros((sample_size, x.ncols()));
        let mut sampled_y = Array1::zeros(sample_size);

        for (i, &idx) in indices.iter().enumerate() {
            sampled_x.row_mut(i).assign(&x.row(idx));
            sampled_y[i] = y[idx];
        }

        Ok((sampled_x, sampled_y))
    }

    /// Fraction of training rows whose prediction falls within 0.5 of the
    /// target; for integer class labels this is exact-match accuracy.
    fn compute_training_accuracy(
        &self,
        model: &dyn TrainedBaseModel,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<f64> {
        let predictions = model.predict(&x.view())?;

        let correct = predictions
            .iter()
            .zip(y.iter())
            .map(|(pred, actual)| {
                if (pred - actual).abs() < 0.5 {
                    1.0
                } else {
                    0.0
                }
            })
            .sum::<f64>();

        Ok(correct / y.len() as f64)
    }
}

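/// The result of `ParallelEnsemble::parallel_fit`: the trained base models
/// plus a snapshot of the training metrics.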
#[derive(Debug)]
pub struct TrainedParallelEnsemble {
    config: EnsembleConfig,
    trained_learners: Vec<TrainedBaseEstimator>,
    training_metrics: TrainingState,
}

impl TrainedParallelEnsemble {
    /// Number of trained base learners.
    pub fn n_estimators(&self) -> usize {
        self.trained_learners.len()
    }

    /// Metrics captured during training (progress, timing, efficiency).
    pub fn training_metrics(&self) -> &TrainingState {
        &self.training_metrics
    }

    /// Predict by querying every base learner in parallel and combining the
    /// per-learner predictions with the configured aggregation method.
    pub fn parallel_predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
        let n_samples = x.nrows();

        // Collect one prediction vector per trained learner, in parallel.
        let all_predictions: Vec<Array1<f64>> = self
            .trained_learners
            .par_iter()
            .map(|learner| learner.model.predict(x))
            .collect::<Result<Vec<_>>>()?;

        let mut final_predictions = Array1::zeros(n_samples);

        match self.config.aggregation_method {
            AggregationMethod::Voting => {
                self.voting_aggregation(&all_predictions, &mut final_predictions)?;
            }
            AggregationMethod::Averaging => {
                self.averaging_aggregation(&all_predictions, &mut final_predictions)?;
            }
            AggregationMethod::WeightedVoting => {
                self.weighted_voting_aggregation(&all_predictions, &mut final_predictions)?;
            }
            AggregationMethod::Stacking => {
                return self.stacking_aggregation(&all_predictions, x);
            }
        }

        Ok(final_predictions)
    }

    /// Majority vote: round each prediction to the nearest integer class and
    /// pick the most frequent class per sample.
    fn voting_aggregation(
        &self,
        predictions: &[Array1<f64>],
        output: &mut Array1<f64>,
    ) -> Result<()> {
        let n_samples = output.len();

        for i in 0..n_samples {
            let mut votes = HashMap::new();

            for pred_array in predictions {
                let vote = pred_array[i].round() as i32;
                *votes.entry(vote).or_insert(0) += 1;
            }

            let majority_vote = votes
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(vote, _)| vote as f64)
                .unwrap_or(0.0);

            output[i] = majority_vote;
        }

        Ok(())
    }

    /// Unweighted mean of the per-learner predictions.
    fn averaging_aggregation(
        &self,
        predictions: &[Array1<f64>],
        output: &mut Array1<f64>,
    ) -> Result<()> {
        let n_estimators = predictions.len() as f64;

        output.fill(0.0);
        for pred_array in predictions {
            for (out, pred) in output.iter_mut().zip(pred_array.iter()) {
                *out += pred;
            }
        }

        for out in output.iter_mut() {
            *out /= n_estimators;
        }

        Ok(())
    }

    /// Weighted average of predictions, where each learner's weight is its
    /// training accuracy.
    fn weighted_voting_aggregation(
        &self,
        predictions: &[Array1<f64>],
        output: &mut Array1<f64>,
    ) -> Result<()> {
        let n_samples = output.len();
        let weights: Vec<f64> = self
            .trained_learners
            .iter()
            .map(|learner| learner.training_accuracy)
            .collect();
        let weight_sum: f64 = weights.iter().sum();

        output.fill(0.0);

        for i in 0..n_samples {
            for (j, pred_array) in predictions.iter().enumerate() {
                output[i] += pred_array[i] * weights[j];
            }
            output[i] /= weight_sum;
        }

        Ok(())
    }

    /// Stacking: assemble meta-features (original features plus one column of
    /// predictions per base learner). A trained meta-learner is not available
    /// here, so the combination currently falls back to simple averaging.
    fn stacking_aggregation(
        &self,
        predictions: &[Array1<f64>],
        original_features: &ArrayView2<f64>,
    ) -> Result<Array1<f64>> {
        let n_samples = original_features.nrows();
        let n_base_features = original_features.ncols();
        let n_meta_features = n_base_features + predictions.len();

        let mut meta_features = Array2::zeros((n_samples, n_meta_features));

        // Copy the original features into the leading columns.
        meta_features
            .slice_mut(s![.., 0..n_base_features])
            .assign(original_features);

        // Append one column of predictions per base learner.
        for (i, pred_array) in predictions.iter().enumerate() {
            meta_features
                .column_mut(n_base_features + i)
                .assign(pred_array);
        }

        // Fallback: average the base predictions in place of a meta-learner.
        let mut result = Array1::zeros(n_samples);
        self.averaging_aggregation(predictions, &mut result)?;
        Ok(result)
    }
}

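/// Top-level configuration for a parallel ensemble.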
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleConfig {
    pub ensemble_type: EnsembleType,
    pub n_estimators: usize,
    pub parallel_config: ParallelConfig,
    pub sampling_strategy: SamplingStrategy,
    pub aggregation_method: AggregationMethod,
    pub base_config: BaseEstimatorConfig,
    pub random_seed: u64,
    pub subsample_ratio: f64,
}

impl EnsembleConfig {
    /// Preset: bootstrap-sampled decision trees combined by majority vote.
    pub fn random_forest() -> Self {
        Self {
            ensemble_type: EnsembleType::RandomForest,
            n_estimators: 100,
            parallel_config: ParallelConfig::default(),
            sampling_strategy: SamplingStrategy::Bootstrap,
            aggregation_method: AggregationMethod::Voting,
            base_config: BaseEstimatorConfig::decision_tree(),
            random_seed: 42,
            subsample_ratio: 1.0,
        }
    }

    /// Preset: decision trees trained on the full data and averaged.
    pub fn gradient_boosting() -> Self {
        Self {
            ensemble_type: EnsembleType::GradientBoosting,
            n_estimators: 100,
            parallel_config: ParallelConfig::default(),
            sampling_strategy: SamplingStrategy::None,
            aggregation_method: AggregationMethod::Averaging,
            base_config: BaseEstimatorConfig::decision_tree(),
            random_seed: 42,
            subsample_ratio: 1.0,
        }
    }

    pub fn with_n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
        self.parallel_config = config;
        self
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EnsembleType {
    RandomForest,
    GradientBoosting,
    AdaBoost,
    Voting,
    Stacking,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
    Bootstrap,
    Bagging,
    None,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    Voting,
    Averaging,
    WeightedVoting,
    Stacking,
}

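/// Thread-pool and scheduling settings for parallel training.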
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
    pub num_threads: usize,
    pub batch_size: usize,
    pub work_stealing: bool,
    pub load_balancing: LoadBalancingStrategy,
}

impl ParallelConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_num_threads(mut self, threads: usize) -> Self {
        self.num_threads = threads;
        self
    }

    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    pub fn with_work_stealing(mut self, enabled: bool) -> Self {
        self.work_stealing = enabled;
        self
    }
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            num_threads: num_cpus::get(),
            batch_size: 1000,
            work_stealing: true,
            load_balancing: LoadBalancingStrategy::Dynamic,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    Static,
    Dynamic,
    WorkStealing,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaseEstimatorConfig {
    pub estimator_type: BaseEstimatorType,
    pub parameters: HashMap<String, f64>,
}

impl BaseEstimatorConfig {
    pub fn decision_tree() -> Self {
        let mut params = HashMap::new();
        params.insert("max_depth".to_string(), 10.0);
        params.insert("min_samples_split".to_string(), 2.0);
        params.insert("min_samples_leaf".to_string(), 1.0);

        Self {
            estimator_type: BaseEstimatorType::DecisionTree,
            parameters: params,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BaseEstimatorType {
    DecisionTree,
    LinearModel,
    NeuralNetwork,
    SVM,
}

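/// Mutable progress and timing information shared across training workers.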
#[derive(Debug, Clone)]
pub struct TrainingState {
    pub total_estimators: usize,
    pub completed_estimators: usize,
    pub failed_estimators: usize,
    pub training_start_time: Option<Instant>,
    pub training_duration: Option<Duration>,
    /// (n_samples, n_features) of the training data; the feature count is
    /// currently not recorded and stays 0.
    pub data_size: (usize, usize),
    pub parallel_efficiency: f64,
}

impl TrainingState {
    pub fn new() -> Self {
        Self {
            total_estimators: 0,
            completed_estimators: 0,
            failed_estimators: 0,
            training_start_time: None,
            training_duration: None,
            data_size: (0, 0),
            parallel_efficiency: 0.0,
        }
    }

    pub fn start_training(&mut self, n_samples: usize, n_estimators: usize) {
        self.total_estimators = n_estimators;
        // The feature count is not known here, so only samples are recorded.
        self.data_size = (n_samples, 0);
        self.training_start_time = Some(Instant::now());
        self.completed_estimators = 0;
        self.failed_estimators = 0;
    }

    pub fn update_progress(&mut self, _learner_id: usize, success: bool) {
        if success {
            self.completed_estimators += 1;
        } else {
            self.failed_estimators += 1;
        }
    }

    pub fn complete_training(&mut self, duration: Duration) {
        self.training_duration = Some(duration);

        // Rough efficiency estimate: treating `duration * total_estimators`
        // as the sequential cost implies a speedup of `total_estimators`, so
        // normalize that by the available cores and cap at 1.0. (The previous
        // ratio `sequential / actual` always capped out at 1.0.)
        let implied_speedup = self.total_estimators as f64;
        let available_parallelism = num_cpus::get() as f64;
        self.parallel_efficiency = if duration.as_secs_f64() > 0.0 {
            (implied_speedup / available_parallelism).min(1.0)
        } else {
            0.0
        };
    }

    /// Fraction of estimators that have finished successfully, as a percentage.
    pub fn progress_percentage(&self) -> f64 {
        if self.total_estimators == 0 {
            0.0
        } else {
            (self.completed_estimators as f64 / self.total_estimators as f64) * 100.0
        }
    }
}


impl Default for TrainingState {
    fn default() -> Self {
        Self::new()
    }
}

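/// An untrained base learner; implementations fit on feature/target views
/// and return a boxed trained model.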
pub trait BaseEstimator: Send + Sync + std::fmt::Debug {
    fn fit(&self, x: &ArrayView2<f64>, y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>>;
    fn get_config(&self) -> &BaseEstimatorConfig;
}

/// A fitted model that can produce predictions and, optionally, feature
/// importances.
pub trait TrainedBaseModel: Send + Sync + std::fmt::Debug {
    fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>>;
    fn get_importance(&self) -> Option<Array1<f64>> {
        None
    }
}

#[derive(Debug)]
pub struct TrainedBaseEstimator {
    pub learner_id: usize,
    pub model: Box<dyn TrainedBaseModel>,
    pub training_time: Duration,
    pub training_accuracy: f64,
}

#[derive(Debug)]
pub struct RandomForestEstimator {
    id: usize,
    config: BaseEstimatorConfig,
}

impl RandomForestEstimator {
    pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
        Self {
            id,
            config: config.clone(),
        }
    }
}

impl BaseEstimator for RandomForestEstimator {
    fn fit(&self, x: &ArrayView2<f64>, _y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>> {
        // Placeholder: simulate training work instead of fitting a real tree.
        std::thread::sleep(Duration::from_millis(10));
        Ok(Box::new(TrainedRandomForestModel {
            id: self.id,
            feature_count: x.ncols(),
            sample_count: x.nrows(),
        }))
    }

    fn get_config(&self) -> &BaseEstimatorConfig {
        &self.config
    }
}

#[derive(Debug)]
#[allow(dead_code)]
pub struct TrainedRandomForestModel {
    id: usize,
    feature_count: usize,
    sample_count: usize,
}

impl TrainedBaseModel for TrainedRandomForestModel {
    fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
        // Placeholder: emit deterministic pseudo-random integer class labels.
        let mut rng = Random::seed(self.id as u64);

        let predictions =
            Array1::from_iter((0..x.nrows()).map(|_| rng.random_range(0.0_f64..3.0_f64).round()));

        Ok(predictions)
    }
}

#[derive(Debug)]
pub struct GradientBoostingEstimator {
    id: usize,
    config: BaseEstimatorConfig,
}

impl GradientBoostingEstimator {
    pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
        Self {
            id,
            config: config.clone(),
        }
    }
}

impl BaseEstimator for GradientBoostingEstimator {
    fn fit(&self, x: &ArrayView2<f64>, _y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>> {
        // Placeholder: simulate training work.
        std::thread::sleep(Duration::from_millis(15));
        Ok(Box::new(TrainedGradientBoostingModel {
            id: self.id,
            feature_count: x.ncols(),
        }))
    }

    fn get_config(&self) -> &BaseEstimatorConfig {
        &self.config
    }
}

#[derive(Debug)]
#[allow(dead_code)]
pub struct TrainedGradientBoostingModel {
    id: usize,
    feature_count: usize,
}

impl TrainedBaseModel for TrainedGradientBoostingModel {
    fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
        // Placeholder: a linear function of each row's feature sum.
        let predictions = Array1::from_iter(x.rows().into_iter().map(|row| row.sum() * 0.1));
        Ok(predictions)
    }
}

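/// Expands to a placeholder estimator/model pair: `fit` sleeps for
/// `$sleep_ms` milliseconds to simulate training work, and `predict` maps
/// each input row through `$prediction_fn`.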
macro_rules! impl_base_estimator {
    ($estimator:ident, $model:ident, $sleep_ms:expr, $prediction_fn:expr) => {
        #[derive(Debug)]
        pub struct $estimator {
            id: usize,
            config: BaseEstimatorConfig,
        }

        impl $estimator {
            pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
                Self {
                    id,
                    config: config.clone(),
                }
            }
        }

        impl BaseEstimator for $estimator {
            fn fit(
                &self,
                x: &ArrayView2<f64>,
                _y: &ArrayView1<f64>,
            ) -> Result<Box<dyn TrainedBaseModel>> {
                std::thread::sleep(Duration::from_millis($sleep_ms));
                Ok(Box::new($model {
                    id: self.id,
                    feature_count: x.ncols(),
                }))
            }

            fn get_config(&self) -> &BaseEstimatorConfig {
                &self.config
            }
        }

        #[derive(Debug)]
        #[allow(dead_code)]
        pub struct $model {
            id: usize,
            feature_count: usize,
        }

        impl TrainedBaseModel for $model {
            fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
                let predictions = Array1::from_iter(x.rows().into_iter().map($prediction_fn));
                Ok(predictions)
            }
        }
    };
}

impl_base_estimator!(AdaBoostEstimator, TrainedAdaBoostModel, 12, |row| row
    .mean()
    .unwrap_or(0.0));
impl_base_estimator!(VotingEstimator, TrainedVotingModel, 8, |row| row
    .iter()
    .max_by(|a, b| a.partial_cmp(b).unwrap())
    .unwrap_or(&0.0)
    * 0.5);
impl_base_estimator!(StackingEstimator, TrainedStackingModel, 20, |row| row.sum()
    / row.len() as f64);

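/// Skeleton for an ensemble distributed across a cluster; cluster
/// membership and training are currently stubs.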
#[derive(Debug)]
pub struct DistributedEnsemble {
    config: DistributedConfig,
}

impl DistributedEnsemble {
    pub fn new(config: DistributedConfig) -> Self {
        Self { config }
    }

    pub async fn join_cluster(&self) -> Result<()> {
        // Stub: a real implementation would contact the coordinator node.
        println!("Joining cluster at {}", self.config.coordinator_address);
        Ok(())
    }

    pub async fn distributed_fit(
        &self,
        _x: &ArrayView2<'_, f64>,
        _y: &ArrayView1<'_, f64>,
    ) -> Result<TrainedDistributedEnsemble> {
        // Stub: no distributed training is performed yet.
        Ok(TrainedDistributedEnsemble {
            cluster_size: self.config.cluster_size,
        })
    }
}

#[derive(Debug, Clone)]
pub struct DistributedConfig {
    pub cluster_size: usize,
    pub node_role: NodeRole,
    pub coordinator_address: String,
    pub fault_tolerance: bool,
    pub checkpointing_interval: Duration,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl DistributedConfig {
    pub fn new() -> Self {
        Self {
            cluster_size: 1,
            node_role: NodeRole::Coordinator,
            coordinator_address: "localhost:8080".to_string(),
            fault_tolerance: false,
            checkpointing_interval: Duration::from_secs(300),
        }
    }

    pub fn with_cluster_size(mut self, size: usize) -> Self {
        self.cluster_size = size;
        self
    }

    pub fn with_node_role(mut self, role: NodeRole) -> Self {
        self.node_role = role;
        self
    }

    pub fn with_coordinator_address(mut self, address: &str) -> Self {
        self.coordinator_address = address.to_string();
        self
    }

    pub fn with_fault_tolerance(mut self, enabled: bool) -> Self {
        self.fault_tolerance = enabled;
        self
    }

    pub fn with_checkpointing_interval(mut self, interval: Duration) -> Self {
        self.checkpointing_interval = interval;
        self
    }
}

#[derive(Debug, Clone)]
pub enum NodeRole {
    Coordinator,
    Worker,
}

#[derive(Debug)]
pub struct TrainedDistributedEnsemble {
    cluster_size: usize,
}

impl TrainedDistributedEnsemble {
    pub fn cluster_size(&self) -> usize {
        self.cluster_size
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ensemble_config_creation() {
        let config = EnsembleConfig::random_forest()
            .with_n_estimators(50)
            .with_parallel_config(ParallelConfig::new().with_num_threads(4));

        assert_eq!(config.n_estimators, 50);
        assert_eq!(config.parallel_config.num_threads, 4);
        assert!(matches!(config.ensemble_type, EnsembleType::RandomForest));
    }

    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::new()
            .with_num_threads(8)
            .with_batch_size(2000)
            .with_work_stealing(false);

        assert_eq!(config.num_threads, 8);
        assert_eq!(config.batch_size, 2000);
        assert!(!config.work_stealing);
    }

    #[test]
    fn test_training_state() {
        let mut state = TrainingState::new();

        state.start_training(1000, 10);
        assert_eq!(state.total_estimators, 10);
        assert_eq!(state.progress_percentage(), 0.0);

        state.update_progress(0, true);
        state.update_progress(1, true);
        state.update_progress(2, false);

        assert_eq!(state.completed_estimators, 2);
        assert_eq!(state.failed_estimators, 1);
        assert_eq!(state.progress_percentage(), 20.0);
    }

    #[test]
    fn test_base_estimator_creation() {
        let config = BaseEstimatorConfig::decision_tree();
        let estimator = RandomForestEstimator::new(0, &config);

        assert!(estimator.get_config().parameters.contains_key("max_depth"));
    }

    #[test]
    fn test_parallel_ensemble_creation() {
        let config = EnsembleConfig::random_forest().with_n_estimators(5);
        let ensemble = ParallelEnsemble::new(config);

        assert_eq!(ensemble.n_estimators(), 5);
    }

    #[test]
    fn test_sampling_strategies() {
        let config = EnsembleConfig::random_forest();
        let ensemble = ParallelEnsemble::new(config);

        let x = Array2::from_shape_vec((10, 3), (0..30).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_shape_vec(10, (0..10).map(|i| i as f64).collect()).unwrap();

        let (sampled_x, sampled_y) = ensemble.bootstrap_sample(&x.view(), &y.view(), 0).unwrap();
        assert_eq!(sampled_x.shape(), x.shape());
        assert_eq!(sampled_y.len(), y.len());
    }

    #[test]
    fn test_aggregation_methods() {
        let config = EnsembleConfig::random_forest();
        let trained_learners = vec![
            TrainedBaseEstimator {
                learner_id: 0,
                model: Box::new(TrainedRandomForestModel {
                    id: 0,
                    feature_count: 3,
                    sample_count: 10,
                }),
                training_time: Duration::from_millis(100),
                training_accuracy: 0.8,
            },
            TrainedBaseEstimator {
                learner_id: 1,
                model: Box::new(TrainedRandomForestModel {
                    id: 1,
                    feature_count: 3,
                    sample_count: 10,
                }),
                training_time: Duration::from_millis(120),
                training_accuracy: 0.9,
            },
        ];

        let ensemble = TrainedParallelEnsemble {
            config,
            trained_learners,
            training_metrics: TrainingState::new(),
        };

        let x = Array2::zeros((5, 3));
        let result = ensemble.parallel_predict(&x.view());
        assert!(result.is_ok());

        let predictions = result.unwrap();
        assert_eq!(predictions.len(), 5);
    }

    #[test]
    fn test_distributed_config() {
        let config = DistributedConfig::new()
            .with_cluster_size(4)
            .with_node_role(NodeRole::Worker)
            .with_coordinator_address("192.168.1.100:8080")
            .with_fault_tolerance(true);

        assert_eq!(config.cluster_size, 4);
        assert!(matches!(config.node_role, NodeRole::Worker));
        assert_eq!(config.coordinator_address, "192.168.1.100:8080");
        assert!(config.fault_tolerance);
    }
}