1use crate::error::{Result, SklearsError};
21use rayon::prelude::*;
80use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
81use scirs2_core::random::Random;
82use serde::{Deserialize, Serialize};
83use std::collections::HashMap;
84use std::sync::{Arc, RwLock};
85use std::time::{Duration, Instant};
86
/// An ensemble of independent base learners that are trained concurrently on a
/// dedicated rayon thread pool.
///
/// Build it with [`ParallelEnsemble::new`] and train it with
/// [`ParallelEnsemble::parallel_fit`], which returns a [`TrainedParallelEnsemble`].
#[derive(Debug)]
pub struct ParallelEnsemble {
    /// Configuration the ensemble was built from (learner kind, count, sampling, aggregation).
    config: EnsembleConfig,
    /// One untrained learner per estimator slot, created up front by `new`.
    base_learners: Vec<Arc<dyn BaseEstimator>>,
    /// Progress counters and timings, written by the worker threads during `parallel_fit`.
    training_state: Arc<RwLock<TrainingState>>,
}
94
95impl ParallelEnsemble {
96 pub fn new(config: EnsembleConfig) -> Self {
98 let base_learners = Self::create_base_learners(&config);
99
100 Self {
101 config,
102 base_learners,
103 training_state: Arc::new(RwLock::new(TrainingState::new())),
104 }
105 }
106
107 fn create_base_learners(config: &EnsembleConfig) -> Vec<Arc<dyn BaseEstimator>> {
109 let mut learners = Vec::new();
110
111 for i in 0..config.n_estimators {
112 let learner: Arc<dyn BaseEstimator> = match &config.ensemble_type {
113 EnsembleType::RandomForest => {
114 Arc::new(RandomForestEstimator::new(i, &config.base_config))
115 }
116 EnsembleType::GradientBoosting => {
117 Arc::new(GradientBoostingEstimator::new(i, &config.base_config))
118 }
119 EnsembleType::AdaBoost => Arc::new(AdaBoostEstimator::new(i, &config.base_config)),
120 EnsembleType::Voting => Arc::new(VotingEstimator::new(i, &config.base_config)),
121 EnsembleType::Stacking => Arc::new(StackingEstimator::new(i, &config.base_config)),
122 };
123 learners.push(learner);
124 }
125
126 learners
127 }
128
129 pub fn n_estimators(&self) -> usize {
131 self.base_learners.len()
132 }
133
134 pub fn parallel_fit(
136 &self,
137 x: &ArrayView2<f64>,
138 y: &ArrayView1<f64>,
139 ) -> Result<TrainedParallelEnsemble> {
140 let start_time = Instant::now();
141
142 {
144 let mut state = self
145 .training_state
146 .write()
147 .unwrap_or_else(|e| e.into_inner());
148 state.start_training(x.nrows(), self.n_estimators());
149 }
150
151 let pool = rayon::ThreadPoolBuilder::new()
153 .num_threads(self.config.parallel_config.num_threads)
154 .build()
155 .map_err(|e| {
156 SklearsError::InvalidInput(format!("Failed to create thread pool: {e}"))
157 })?;
158
159 let trained_learners = pool.install(|| {
161 self.base_learners
162 .par_iter()
163 .enumerate()
164 .map(|(i, learner)| {
165 let result = self.train_single_learner(learner.as_ref(), x, y, i);
166
167 {
169 let mut state = self
170 .training_state
171 .write()
172 .unwrap_or_else(|e| e.into_inner());
173 state.update_progress(i, result.is_ok());
174 }
175
176 result
177 })
178 .collect::<Result<Vec<_>>>()
179 })?;
180
181 {
183 let mut state = self
184 .training_state
185 .write()
186 .unwrap_or_else(|e| e.into_inner());
187 state.complete_training(start_time.elapsed());
188 }
189
190 Ok(TrainedParallelEnsemble {
191 config: self.config.clone(),
192 trained_learners,
193 training_metrics: self
194 .training_state
195 .read()
196 .unwrap_or_else(|e| e.into_inner())
197 .clone(),
198 })
199 }
200
201 fn train_single_learner(
203 &self,
204 learner: &dyn BaseEstimator,
205 x: &ArrayView2<f64>,
206 y: &ArrayView1<f64>,
207 learner_id: usize,
208 ) -> Result<TrainedBaseEstimator> {
209 let (train_x, train_y) = self.prepare_training_data(x, y, learner_id)?;
211
212 let start_time = Instant::now();
214 let trained = learner.fit(&train_x.view(), &train_y.view())?;
215 let training_time = start_time.elapsed();
216
217 let training_accuracy =
219 self.compute_training_accuracy(trained.as_ref(), &train_x, &train_y)?;
220
221 Ok(TrainedBaseEstimator {
222 learner_id,
223 model: trained,
224 training_time,
225 training_accuracy,
226 })
227 }
228
229 fn prepare_training_data(
231 &self,
232 x: &ArrayView2<f64>,
233 y: &ArrayView1<f64>,
234 learner_id: usize,
235 ) -> Result<(Array2<f64>, Array1<f64>)> {
236 match self.config.sampling_strategy {
237 SamplingStrategy::Bootstrap => self.bootstrap_sample(x, y, learner_id),
238 SamplingStrategy::Bagging => self.bag_sample(x, y, learner_id),
239 SamplingStrategy::None => Ok((x.to_owned(), y.to_owned())),
240 }
241 }
242
243 fn bootstrap_sample(
245 &self,
246 x: &ArrayView2<f64>,
247 y: &ArrayView1<f64>,
248 seed: usize,
249 ) -> Result<(Array2<f64>, Array1<f64>)> {
250 let mut rng = Random::seed(self.config.random_seed + seed as u64);
251 let n_samples = x.nrows();
252
253 let mut sampled_x = Array2::zeros((n_samples, x.ncols()));
254 let mut sampled_y = Array1::zeros(n_samples);
255
256 for i in 0..n_samples {
257 let sample_idx = rng.gen_range(0..n_samples);
258 sampled_x.row_mut(i).assign(&x.row(sample_idx));
259 sampled_y[i] = y[sample_idx];
260 }
261
262 Ok((sampled_x, sampled_y))
263 }
264
265 fn bag_sample(
267 &self,
268 x: &ArrayView2<f64>,
269 y: &ArrayView1<f64>,
270 seed: usize,
271 ) -> Result<(Array2<f64>, Array1<f64>)> {
272 let mut rng = Random::seed(self.config.random_seed + seed as u64);
273 let n_samples = x.nrows();
274 let sample_size = (n_samples as f64 * self.config.subsample_ratio).round() as usize;
275
276 let mut indices: Vec<usize> = (0..n_samples).collect();
277 rng.shuffle(&mut indices);
278 indices.truncate(sample_size);
279
280 let mut sampled_x = Array2::zeros((sample_size, x.ncols()));
281 let mut sampled_y = Array1::zeros(sample_size);
282
283 for (i, &idx) in indices.iter().enumerate() {
284 sampled_x.row_mut(i).assign(&x.row(idx));
285 sampled_y[i] = y[idx];
286 }
287
288 Ok((sampled_x, sampled_y))
289 }
290
291 fn compute_training_accuracy(
293 &self,
294 model: &dyn TrainedBaseModel,
295 x: &Array2<f64>,
296 y: &Array1<f64>,
297 ) -> Result<f64> {
298 let predictions = model.predict(&x.view())?;
299
300 let correct = predictions
301 .iter()
302 .zip(y.iter())
303 .map(|(pred, actual)| {
304 if (pred - actual).abs() < 0.5 {
305 1.0
306 } else {
307 0.0
308 }
309 })
310 .sum::<f64>();
311
312 Ok(correct / y.len() as f64)
313 }
314}
315
/// A fitted ensemble produced by [`ParallelEnsemble::parallel_fit`].
#[derive(Debug)]
pub struct TrainedParallelEnsemble {
    /// Configuration copied from the ensemble that was fitted; `aggregation_method`
    /// selects how predictions are combined.
    config: EnsembleConfig,
    /// Fitted learners with their per-learner timing and training accuracy.
    trained_learners: Vec<TrainedBaseEstimator>,
    /// Snapshot of the training state taken after fitting finished.
    training_metrics: TrainingState,
}
323
324impl TrainedParallelEnsemble {
325 pub fn n_estimators(&self) -> usize {
327 self.trained_learners.len()
328 }
329
330 pub fn training_metrics(&self) -> &TrainingState {
332 &self.training_metrics
333 }
334
335 pub fn parallel_predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
337 let n_samples = x.nrows();
338 let _n_estimators = self.trained_learners.len();
339
340 let all_predictions: Vec<Array1<f64>> = self
342 .trained_learners
343 .par_iter()
344 .map(|learner| learner.model.predict(x))
345 .collect::<Result<Vec<_>>>()?;
346
347 let mut final_predictions = Array1::zeros(n_samples);
349
350 match self.config.aggregation_method {
351 AggregationMethod::Voting => {
352 self.voting_aggregation(&all_predictions, &mut final_predictions)?;
353 }
354 AggregationMethod::Averaging => {
355 self.averaging_aggregation(&all_predictions, &mut final_predictions)?;
356 }
357 AggregationMethod::WeightedVoting => {
358 self.weighted_voting_aggregation(&all_predictions, &mut final_predictions)?;
359 }
360 AggregationMethod::Stacking => {
361 return self.stacking_aggregation(&all_predictions, x);
362 }
363 }
364
365 Ok(final_predictions)
366 }
367
368 fn voting_aggregation(
370 &self,
371 predictions: &[Array1<f64>],
372 output: &mut Array1<f64>,
373 ) -> Result<()> {
374 let n_samples = output.len();
375
376 for i in 0..n_samples {
377 let mut votes = HashMap::new();
378
379 for pred_array in predictions {
380 let vote = pred_array[i].round() as i32;
381 *votes.entry(vote).or_insert(0) += 1;
382 }
383
384 let majority_vote = votes
385 .into_iter()
386 .max_by_key(|(_, count)| *count)
387 .map(|(vote, _)| vote as f64)
388 .unwrap_or(0.0);
389
390 output[i] = majority_vote;
391 }
392
393 Ok(())
394 }
395
396 fn averaging_aggregation(
398 &self,
399 predictions: &[Array1<f64>],
400 output: &mut Array1<f64>,
401 ) -> Result<()> {
402 let n_estimators = predictions.len() as f64;
403
404 output.fill(0.0);
406 for pred_array in predictions {
407 for (out, pred) in output.iter_mut().zip(pred_array.iter()) {
408 *out += pred;
409 }
410 }
411
412 for out in output.iter_mut() {
413 *out /= n_estimators;
414 }
415
416 Ok(())
417 }
418
419 fn weighted_voting_aggregation(
421 &self,
422 predictions: &[Array1<f64>],
423 output: &mut Array1<f64>,
424 ) -> Result<()> {
425 let n_samples = output.len();
426 let weights: Vec<f64> = self
427 .trained_learners
428 .iter()
429 .map(|learner| learner.training_accuracy)
430 .collect();
431 let weight_sum: f64 = weights.iter().sum();
432
433 output.fill(0.0);
434
435 for i in 0..n_samples {
436 for (j, pred_array) in predictions.iter().enumerate() {
437 output[i] += pred_array[i] * weights[j];
438 }
439 output[i] /= weight_sum;
440 }
441
442 Ok(())
443 }
444
445 fn stacking_aggregation(
447 &self,
448 predictions: &[Array1<f64>],
449 original_features: &ArrayView2<f64>,
450 ) -> Result<Array1<f64>> {
451 let n_samples = original_features.nrows();
453 let n_base_features = original_features.ncols();
454 let n_meta_features = n_base_features + predictions.len();
455
456 let mut meta_features = Array2::zeros((n_samples, n_meta_features));
457
458 meta_features
460 .slice_mut(s![.., 0..n_base_features])
461 .assign(original_features);
462
463 for (i, pred_array) in predictions.iter().enumerate() {
465 meta_features
466 .column_mut(n_base_features + i)
467 .assign(pred_array);
468 }
469
470 let mut result = Array1::zeros(n_samples);
473 self.averaging_aggregation(predictions, &mut result)?;
474 Ok(result)
475 }
476}
477
/// Configuration for a [`ParallelEnsemble`].
///
/// Start from a preset such as [`EnsembleConfig::random_forest`] or
/// [`EnsembleConfig::gradient_boosting`] and adjust it with the `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleConfig {
    /// Which base-learner implementation `ParallelEnsemble::new` instantiates.
    pub ensemble_type: EnsembleType,
    /// Number of base learners to create and train.
    pub n_estimators: usize,
    /// Thread-pool settings; `num_threads` sizes the pool used by `parallel_fit`.
    pub parallel_config: ParallelConfig,
    /// How each learner's training set is drawn from the input data.
    pub sampling_strategy: SamplingStrategy,
    /// How per-learner predictions are combined by `parallel_predict`.
    pub aggregation_method: AggregationMethod,
    /// Configuration cloned into every base learner.
    pub base_config: BaseEstimatorConfig,
    /// Base RNG seed; learner `i` resamples with this seed offset by `i`.
    pub random_seed: u64,
    /// Fraction of rows drawn per learner; only used by `SamplingStrategy::Bagging`.
    pub subsample_ratio: f64,
}
490
491impl EnsembleConfig {
492 pub fn random_forest() -> Self {
494 Self {
495 ensemble_type: EnsembleType::RandomForest,
496 n_estimators: 100,
497 parallel_config: ParallelConfig::default(),
498 sampling_strategy: SamplingStrategy::Bootstrap,
499 aggregation_method: AggregationMethod::Voting,
500 base_config: BaseEstimatorConfig::decision_tree(),
501 random_seed: 42,
502 subsample_ratio: 1.0,
503 }
504 }
505
506 pub fn gradient_boosting() -> Self {
508 Self {
509 ensemble_type: EnsembleType::GradientBoosting,
510 n_estimators: 100,
511 parallel_config: ParallelConfig::default(),
512 sampling_strategy: SamplingStrategy::None,
513 aggregation_method: AggregationMethod::Averaging,
514 base_config: BaseEstimatorConfig::decision_tree(),
515 random_seed: 42,
516 subsample_ratio: 1.0,
517 }
518 }
519
520 pub fn with_n_estimators(mut self, n: usize) -> Self {
522 self.n_estimators = n;
523 self
524 }
525
526 pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
528 self.parallel_config = config;
529 self
530 }
531}
532
/// Kind of base learner instantiated for every estimator slot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EnsembleType {
    /// Uses `RandomForestEstimator` learners.
    RandomForest,
    /// Uses `GradientBoostingEstimator` learners.
    GradientBoosting,
    /// Uses `AdaBoostEstimator` learners.
    AdaBoost,
    /// Uses `VotingEstimator` learners.
    Voting,
    /// Uses `StackingEstimator` learners.
    Stacking,
}

/// Strategy used to build each learner's training set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Draw `n_samples` rows with replacement (seeded per learner).
    Bootstrap,
    /// Shuffle the row indices and keep `round(n_samples * subsample_ratio)` of them
    /// (sampling without replacement, seeded per learner).
    Bagging,
    /// Train every learner on a full copy of the data.
    None,
}

/// How `TrainedParallelEnsemble::parallel_predict` combines per-learner predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Majority vote over predictions rounded to the nearest integer.
    Voting,
    /// Unweighted mean of the predictions.
    Averaging,
    /// Mean of the predictions weighted by each learner's training accuracy.
    WeightedVoting,
    /// Builds stacked meta-features but currently returns the unweighted mean
    /// (no meta-learner is implemented yet).
    Stacking,
}

/// Thread-pool settings for parallel training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
    /// Number of threads in the rayon pool built by `ParallelEnsemble::parallel_fit`.
    pub num_threads: usize,
    /// Batch size hint. NOTE(review): not read by `parallel_fit`; confirm other users.
    pub batch_size: usize,
    /// Work-stealing flag. NOTE(review): not read by `parallel_fit`; confirm other users.
    pub work_stealing: bool,
    /// Load-balancing policy. NOTE(review): not read by `parallel_fit`; confirm other users.
    pub load_balancing: LoadBalancingStrategy,
}
568
569impl ParallelConfig {
570 pub fn new() -> Self {
571 Self::default()
572 }
573
574 pub fn with_num_threads(mut self, threads: usize) -> Self {
575 self.num_threads = threads;
576 self
577 }
578
579 pub fn with_batch_size(mut self, size: usize) -> Self {
580 self.batch_size = size;
581 self
582 }
583
584 pub fn with_work_stealing(mut self, enabled: bool) -> Self {
585 self.work_stealing = enabled;
586 self
587 }
588}
589
590impl Default for ParallelConfig {
591 fn default() -> Self {
592 Self {
593 num_threads: num_cpus::get(),
594 batch_size: 1000,
595 work_stealing: true,
596 load_balancing: LoadBalancingStrategy::Dynamic,
597 }
598 }
599}
600
/// Load-balancing policy stored in `ParallelConfig::load_balancing`.
///
/// NOTE(review): no code visible in this module branches on these variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    Static,
    Dynamic,
    WorkStealing,
}

/// Configuration handed (cloned) to every base learner.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaseEstimatorConfig {
    /// Family of base model this configuration describes.
    pub estimator_type: BaseEstimatorType,
    /// Named numeric hyperparameters (e.g. `max_depth` in the decision-tree preset).
    pub parameters: HashMap<String, f64>,
}
615
616impl BaseEstimatorConfig {
617 pub fn decision_tree() -> Self {
618 let mut params = HashMap::new();
619 params.insert("max_depth".to_string(), 10.0);
620 params.insert("min_samples_split".to_string(), 2.0);
621 params.insert("min_samples_leaf".to_string(), 1.0);
622
623 Self {
624 estimator_type: BaseEstimatorType::DecisionTree,
625 parameters: params,
626 }
627 }
628}
629
/// Family of base model described by a [`BaseEstimatorConfig`].
///
/// NOTE(review): the estimator implementations in this module store their config
/// but do not branch on this value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BaseEstimatorType {
    DecisionTree,
    LinearModel,
    NeuralNetwork,
    SVM,
}
638
/// Progress counters and timings for one training run.
///
/// `ParallelEnsemble` shares it across worker threads behind an `Arc<RwLock<_>>`
/// and clones a snapshot into the fitted ensemble.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Number of learners scheduled in the current run.
    pub total_estimators: usize,
    /// Learners whose training finished successfully.
    pub completed_estimators: usize,
    /// Learners whose training returned an error.
    pub failed_estimators: usize,
    /// When `start_training` was last called.
    pub training_start_time: Option<Instant>,
    /// Duration recorded by `complete_training`; `None` until the current run completes.
    pub training_duration: Option<Duration>,
    /// `(n_samples, n_features)`; `start_training` records 0 for the feature count.
    pub data_size: (usize, usize),
    /// Efficiency derived by `complete_training` (see the note there).
    pub parallel_efficiency: f64,
}

impl TrainingState {
    /// Creates a state with no run recorded.
    pub fn new() -> Self {
        Self {
            total_estimators: 0,
            completed_estimators: 0,
            failed_estimators: 0,
            training_start_time: None,
            training_duration: None,
            data_size: (0, 0),
            parallel_efficiency: 0.0,
        }
    }

    /// Resets the state for a new run over `n_samples` rows and `n_estimators` learners.
    ///
    /// All results of a previous run (counters, duration, efficiency) are cleared so
    /// a reused state never reports stale metrics for the run in progress.
    pub fn start_training(&mut self, n_samples: usize, n_estimators: usize) {
        self.total_estimators = n_estimators;
        // The feature count is not passed in, so only the sample count is recorded.
        self.data_size = (n_samples, 0);
        self.training_start_time = Some(Instant::now());
        self.training_duration = None;
        self.parallel_efficiency = 0.0;
        self.completed_estimators = 0;
        self.failed_estimators = 0;
    }

    /// Records the outcome of one learner (the id is currently unused).
    pub fn update_progress(&mut self, _learner_id: usize, success: bool) {
        if success {
            self.completed_estimators += 1;
        } else {
            self.failed_estimators += 1;
        }
    }

    /// Records the wall-clock `duration` of the run and derives `parallel_efficiency`.
    pub fn complete_training(&mut self, duration: Duration) {
        self.training_duration = Some(duration);

        // NOTE(review): the "sequential" estimate is `duration * total_estimators`,
        // so the ratio below equals `total_estimators`; after the clamp the result
        // is 1.0 whenever at least one estimator was scheduled and 0.0 otherwise
        // (or for a zero duration). A meaningful efficiency needs the summed
        // per-learner training time and the thread count, which are not known here.
        let sequential_time_estimate = duration.as_secs_f64() * self.total_estimators as f64;
        let actual_time = duration.as_secs_f64();
        self.parallel_efficiency = if actual_time > 0.0 {
            (sequential_time_estimate / actual_time).min(1.0)
        } else {
            0.0
        };
    }

    /// Percentage of scheduled learners that completed successfully.
    ///
    /// Failed learners are not counted. Returns 0.0 when nothing is scheduled.
    pub fn progress_percentage(&self) -> f64 {
        if self.total_estimators == 0 {
            0.0
        } else {
            (self.completed_estimators as f64 / self.total_estimators as f64) * 100.0
        }
    }
}

impl Default for TrainingState {
    fn default() -> Self {
        Self::new()
    }
}
707
/// An untrained learner that can be fitted on a feature matrix and a target vector.
///
/// `Send + Sync` is required because `ParallelEnsemble` shares learners across
/// rayon worker threads during fitting.
pub trait BaseEstimator: Send + Sync + std::fmt::Debug {
    /// Fits a model on `x`/`y` and returns it as a new boxed object; the estimator
    /// itself is only borrowed immutably.
    fn fit(&self, x: &ArrayView2<f64>, y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>>;
    /// Configuration this estimator was created with.
    fn get_config(&self) -> &BaseEstimatorConfig;
}

/// A fitted model.
///
/// The ensemble aggregators index predictions by sample position, so `predict`
/// is expected to return one value per row of `x`.
pub trait TrainedBaseModel: Send + Sync + std::fmt::Debug {
    /// Predicts one value per row of `x`.
    fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>>;
    /// Optional feature importances; `None` unless a model overrides it.
    fn get_importance(&self) -> Option<Array1<f64>> {
        None
    }
}

/// A fitted learner together with its bookkeeping from `parallel_fit`.
#[derive(Debug)]
pub struct TrainedBaseEstimator {
    /// Index of the base learner this model came from.
    pub learner_id: usize,
    /// The fitted model.
    pub model: Box<dyn TrainedBaseModel>,
    /// Wall-clock time spent inside `BaseEstimator::fit`.
    pub training_time: Duration,
    /// Fraction of (resampled) training rows predicted within 0.5 of the target;
    /// used as the weight for `AggregationMethod::WeightedVoting`.
    pub training_accuracy: f64,
}
730
731#[derive(Debug)]
733pub struct RandomForestEstimator {
734 id: usize,
735 config: BaseEstimatorConfig,
736}
737
738impl RandomForestEstimator {
739 pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
740 Self {
741 id,
742 config: config.clone(),
743 }
744 }
745}
746
747impl BaseEstimator for RandomForestEstimator {
748 fn fit(&self, x: &ArrayView2<f64>, _y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>> {
749 std::thread::sleep(Duration::from_millis(10)); Ok(Box::new(TrainedRandomForestModel {
753 id: self.id,
754 feature_count: x.ncols(),
755 sample_count: x.nrows(),
756 }))
757 }
758
759 fn get_config(&self) -> &BaseEstimatorConfig {
760 &self.config
761 }
762}
763
764#[derive(Debug)]
766#[allow(dead_code)]
767pub struct TrainedRandomForestModel {
768 id: usize,
769 feature_count: usize,
770 sample_count: usize,
771}
772
773impl TrainedBaseModel for TrainedRandomForestModel {
774 fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
775 let mut rng = Random::seed(self.id as u64);
777
778 let predictions =
779 Array1::from_iter((0..x.nrows()).map(|_| rng.random_range(0.0_f64..3.0_f64).round()));
780
781 Ok(predictions)
782 }
783}
784
785#[derive(Debug)]
787pub struct GradientBoostingEstimator {
788 id: usize,
789 config: BaseEstimatorConfig,
790}
791
792impl GradientBoostingEstimator {
793 pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
794 Self {
795 id,
796 config: config.clone(),
797 }
798 }
799}
800
801impl BaseEstimator for GradientBoostingEstimator {
802 fn fit(&self, x: &ArrayView2<f64>, _y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>> {
803 std::thread::sleep(Duration::from_millis(15));
804 Ok(Box::new(TrainedGradientBoostingModel {
805 id: self.id,
806 feature_count: x.ncols(),
807 }))
808 }
809
810 fn get_config(&self) -> &BaseEstimatorConfig {
811 &self.config
812 }
813}
814
815#[derive(Debug)]
816#[allow(dead_code)]
817pub struct TrainedGradientBoostingModel {
818 id: usize,
819 feature_count: usize,
820}
821
822impl TrainedBaseModel for TrainedGradientBoostingModel {
823 fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
824 let predictions = Array1::from_iter(x.rows().into_iter().map(|row| row.sum() * 0.1));
825 Ok(predictions)
826 }
827}
828
// Generates a placeholder estimator / fitted-model pair.
//
// - `$estimator`: name of the generated `BaseEstimator` type.
// - `$model`: name of the generated `TrainedBaseModel` type.
// - `$sleep_ms`: milliseconds `fit` sleeps to simulate training cost.
// - `$prediction_fn`: closure applied to each row of `x` in `predict`; its result
//   is that row's prediction.
macro_rules! impl_base_estimator {
    ($estimator:ident, $model:ident, $sleep_ms:expr, $prediction_fn:expr) => {
        #[derive(Debug)]
        pub struct $estimator {
            id: usize,
            config: BaseEstimatorConfig,
        }

        impl $estimator {
            pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
                Self {
                    id,
                    config: config.clone(),
                }
            }
        }

        impl BaseEstimator for $estimator {
            fn fit(
                &self,
                x: &ArrayView2<f64>,
                _y: &ArrayView1<f64>,
            ) -> Result<Box<dyn TrainedBaseModel>> {
                // Simulated training cost; the targets are ignored.
                std::thread::sleep(Duration::from_millis($sleep_ms));
                Ok(Box::new($model {
                    id: self.id,
                    feature_count: x.ncols(),
                }))
            }

            fn get_config(&self) -> &BaseEstimatorConfig {
                &self.config
            }
        }

        #[derive(Debug)]
        // `id` and `feature_count` are recorded at fit time but never read.
        #[allow(dead_code)]
        pub struct $model {
            id: usize,
            feature_count: usize,
        }

        impl TrainedBaseModel for $model {
            fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
                // One prediction per row, computed by the caller-supplied closure.
                let predictions = Array1::from_iter(x.rows().into_iter().map($prediction_fn));
                Ok(predictions)
            }
        }
    };
}
880
881impl_base_estimator!(AdaBoostEstimator, TrainedAdaBoostModel, 12, |row| row
882 .mean()
883 .unwrap_or(0.0));
884impl_base_estimator!(VotingEstimator, TrainedVotingModel, 8, |row| row
885 .iter()
886 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
887 .unwrap_or(&0.0)
888 * 0.5);
889impl_base_estimator!(StackingEstimator, TrainedStackingModel, 20, |row| row.sum()
890 / row.len() as f64);
891
892#[derive(Debug)]
894pub struct DistributedEnsemble {
895 config: DistributedConfig,
896}
897
898impl DistributedEnsemble {
899 pub fn new(config: DistributedConfig) -> Self {
900 Self { config }
901 }
902
903 pub async fn join_cluster(&self) -> Result<()> {
904 println!("Joining cluster at {}", self.config.coordinator_address);
906 Ok(())
907 }
908
909 pub async fn distributed_fit(
910 &self,
911 _x: &ArrayView2<'_, f64>,
912 _y: &ArrayView1<'_, f64>,
913 ) -> Result<TrainedDistributedEnsemble> {
914 Ok(TrainedDistributedEnsemble {
916 cluster_size: self.config.cluster_size,
917 })
918 }
919}
920
/// Settings for a node participating in a distributed ensemble.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Number of nodes in the cluster.
    pub cluster_size: usize,
    /// Whether this node coordinates or works.
    pub node_role: NodeRole,
    /// Address of the coordinator, e.g. `localhost:8080`.
    pub coordinator_address: String,
    /// Whether fault tolerance is requested.
    pub fault_tolerance: bool,
    /// Interval between checkpoints.
    pub checkpointing_interval: Duration,
}

impl Default for DistributedConfig {
    /// Single-node coordinator at `localhost:8080`, fault tolerance off,
    /// checkpoints every 300 seconds.
    fn default() -> Self {
        DistributedConfig {
            cluster_size: 1,
            node_role: NodeRole::Coordinator,
            coordinator_address: String::from("localhost:8080"),
            fault_tolerance: false,
            checkpointing_interval: Duration::from_secs(300),
        }
    }
}

impl DistributedConfig {
    /// Equivalent to [`DistributedConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the settings with `cluster_size` replaced by `size`.
    pub fn with_cluster_size(self, size: usize) -> Self {
        Self {
            cluster_size: size,
            ..self
        }
    }

    /// Returns the settings with `node_role` replaced by `role`.
    pub fn with_node_role(self, role: NodeRole) -> Self {
        Self {
            node_role: role,
            ..self
        }
    }

    /// Returns the settings with the coordinator address replaced by `address`.
    pub fn with_coordinator_address(self, address: &str) -> Self {
        Self {
            coordinator_address: address.to_owned(),
            ..self
        }
    }

    /// Returns the settings with fault tolerance switched on or off.
    pub fn with_fault_tolerance(self, enabled: bool) -> Self {
        Self {
            fault_tolerance: enabled,
            ..self
        }
    }

    /// Returns the settings with `checkpointing_interval` replaced by `interval`.
    pub fn with_checkpointing_interval(self, interval: Duration) -> Self {
        Self {
            checkpointing_interval: interval,
            ..self
        }
    }
}

/// Role of a node within the cluster.
#[derive(Debug, Clone)]
pub enum NodeRole {
    Coordinator,
    Worker,
}
980
/// Result of `DistributedEnsemble::distributed_fit`; currently only records the
/// cluster size.
#[derive(Debug)]
pub struct TrainedDistributedEnsemble {
    cluster_size: usize,
}

impl TrainedDistributedEnsemble {
    /// Number of nodes the ensemble was trained on.
    pub fn cluster_size(&self) -> usize {
        let TrainedDistributedEnsemble { cluster_size } = self;
        *cluster_size
    }
}
992
993#[allow(non_snake_case)]
994#[cfg(test)]
995mod tests {
996 use super::*;
997
998 #[test]
999 fn test_ensemble_config_creation() {
1000 let config = EnsembleConfig::random_forest()
1001 .with_n_estimators(50)
1002 .with_parallel_config(ParallelConfig::new().with_num_threads(4));
1003
1004 assert_eq!(config.n_estimators, 50);
1005 assert_eq!(config.parallel_config.num_threads, 4);
1006 assert!(matches!(config.ensemble_type, EnsembleType::RandomForest));
1007 }
1008
1009 #[test]
1010 fn test_parallel_config() {
1011 let config = ParallelConfig::new()
1012 .with_num_threads(8)
1013 .with_batch_size(2000)
1014 .with_work_stealing(false);
1015
1016 assert_eq!(config.num_threads, 8);
1017 assert_eq!(config.batch_size, 2000);
1018 assert!(!config.work_stealing);
1019 }
1020
1021 #[test]
1022 fn test_training_state() {
1023 let mut state = TrainingState::new();
1024
1025 state.start_training(1000, 10);
1026 assert_eq!(state.total_estimators, 10);
1027 assert_eq!(state.progress_percentage(), 0.0);
1028
1029 state.update_progress(0, true);
1030 state.update_progress(1, true);
1031 state.update_progress(2, false);
1032
1033 assert_eq!(state.completed_estimators, 2);
1034 assert_eq!(state.failed_estimators, 1);
1035 assert_eq!(state.progress_percentage(), 20.0);
1036 }
1037
1038 #[test]
1039 fn test_base_estimator_creation() {
1040 let config = BaseEstimatorConfig::decision_tree();
1041 let estimator = RandomForestEstimator::new(0, &config);
1042
1043 assert!(estimator.get_config().parameters.contains_key("max_depth"));
1044 }
1045
1046 #[test]
1047 fn test_parallel_ensemble_creation() {
1048 let config = EnsembleConfig::random_forest().with_n_estimators(5);
1049 let ensemble = ParallelEnsemble::new(config);
1050
1051 assert_eq!(ensemble.n_estimators(), 5);
1052 }
1053
1054 #[test]
1055 fn test_sampling_strategies() {
1056 let config = EnsembleConfig::random_forest();
1057 let ensemble = ParallelEnsemble::new(config);
1058
1059 let x = Array2::from_shape_vec((10, 3), (0..30).map(|i| i as f64).collect())
1060 .expect("valid array shape");
1061 let y = Array1::from_shape_vec(10, (0..10).map(|i| i as f64).collect())
1062 .expect("valid array shape");
1063
1064 let (sampled_x, sampled_y) = ensemble
1065 .bootstrap_sample(&x.view(), &y.view(), 0)
1066 .expect("expected valid value");
1067 assert_eq!(sampled_x.shape(), x.shape());
1068 assert_eq!(sampled_y.len(), y.len());
1069 }
1070
1071 #[test]
1072 fn test_aggregation_methods() {
1073 let config = EnsembleConfig::random_forest();
1074 let trained_learners = vec![
1075 TrainedBaseEstimator {
1076 learner_id: 0,
1077 model: Box::new(TrainedRandomForestModel {
1078 id: 0,
1079 feature_count: 3,
1080 sample_count: 10,
1081 }),
1082 training_time: Duration::from_millis(100),
1083 training_accuracy: 0.8,
1084 },
1085 TrainedBaseEstimator {
1086 learner_id: 1,
1087 model: Box::new(TrainedRandomForestModel {
1088 id: 1,
1089 feature_count: 3,
1090 sample_count: 10,
1091 }),
1092 training_time: Duration::from_millis(120),
1093 training_accuracy: 0.9,
1094 },
1095 ];
1096
1097 let ensemble = TrainedParallelEnsemble {
1098 config,
1099 trained_learners,
1100 training_metrics: TrainingState::new(),
1101 };
1102
1103 let x = Array2::zeros((5, 3));
1104 let result = ensemble.parallel_predict(&x.view());
1105 assert!(result.is_ok());
1106
1107 let predictions = result.expect("expected valid value");
1108 assert_eq!(predictions.len(), 5);
1109 }
1110
1111 #[test]
1112 fn test_distributed_config() {
1113 let config = DistributedConfig::new()
1114 .with_cluster_size(4)
1115 .with_node_role(NodeRole::Worker)
1116 .with_coordinator_address("192.168.1.100:8080")
1117 .with_fault_tolerance(true);
1118
1119 assert_eq!(config.cluster_size, 4);
1120 assert!(matches!(config.node_role, NodeRole::Worker));
1121 assert_eq!(config.coordinator_address, "192.168.1.100:8080");
1122 assert!(config.fault_tolerance);
1123 }
1124}