use sklears_core::error::{Result, SklearsError};
use std::thread;
use std::time::Instant;

#[cfg(feature = "parallel")]
use rayon::prelude::*;
#[cfg(feature = "parallel")]
use rayon::ThreadPoolBuilder;

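/// Strategy for distributing ensemble training across parallel workers.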
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParallelStrategy {
    /// Parallelize by partitioning the training workload.
    DataParallel,
    /// Parallelize across distinct models.
    ModelParallel,
    /// Parallelize across ensemble members.
    EnsembleParallel,
    /// Combine multiple parallelization strategies.
    Hybrid,
}

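/// Configuration for parallel ensemble training.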
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of workers (`None` = detect from available parallelism).
    pub n_workers: Option<usize>,
    /// Parallelization strategy to use.
    pub strategy: ParallelStrategy,
    /// Batch size for data processing (`None` = automatic).
    pub batch_size: Option<usize>,
    /// Enable work stealing between idle and busy workers.
    pub work_stealing: bool,
    /// Size of the underlying thread pool (`None` = match `n_workers`).
    pub thread_pool_size: Option<usize>,
    /// Per-worker memory limit in megabytes (`None` = unlimited).
    pub memory_limit_mb: Option<usize>,
    /// Enable load balancing across workers.
    pub load_balancing: bool,
    /// Size of the inter-worker communication buffer.
    pub communication_buffer_size: usize,
    /// Timeout for individual workers, in seconds (`None` = no timeout).
    pub worker_timeout_secs: Option<u64>,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            n_workers: None,
            strategy: ParallelStrategy::DataParallel,
            batch_size: None,
            work_stealing: true,
            thread_pool_size: None,
            memory_limit_mb: None,
            load_balancing: true,
            communication_buffer_size: 1024,
            worker_timeout_secs: None,
        }
    }
}

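/// Coordinates parallel training of ensemble members according to a
/// [`ParallelConfig`].
///
/// A minimal usage sketch (`MyTrainer` is a hypothetical type implementing
/// [`ParallelTrainable`]):
///
/// ```ignore
/// let mut trainer = ParallelEnsembleTrainer::auto();
/// let models = trainer.train_data_parallel(&MyTrainer::new(), &x, &y, 100)?;
/// println!(
///     "parallel efficiency: {:.2}",
///     trainer.performance_metrics().parallel_efficiency
/// );
/// ```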
pub struct ParallelEnsembleTrainer {
    config: ParallelConfig,
    performance_metrics: ParallelPerformanceMetrics,
}

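/// Timing and efficiency statistics collected during parallel training.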
#[derive(Debug, Clone)]
pub struct ParallelPerformanceMetrics {
    /// Total wall-clock training time in seconds.
    pub total_time_secs: f64,
    /// Time spent in the parallel phase, in seconds.
    pub parallel_time_secs: f64,
    /// Time spent synchronizing results, in seconds.
    pub sync_time_secs: f64,
    /// Number of workers actually used.
    pub workers_used: usize,
    /// Ratio of ideal parallel time to measured parallel time.
    pub parallel_efficiency: f64,
    /// Per-worker memory usage samples, in megabytes.
    pub memory_usage_mb: Vec<f64>,
    /// Estimated efficiency of load balancing across workers.
    pub load_balance_efficiency: f64,
}

impl Default for ParallelPerformanceMetrics {
    fn default() -> Self {
        Self {
            total_time_secs: 0.0,
            parallel_time_secs: 0.0,
            sync_time_secs: 0.0,
            workers_used: 0,
            parallel_efficiency: 0.0,
            memory_usage_mb: Vec::new(),
            load_balance_efficiency: 0.0,
        }
    }
}

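/// Interface that models must implement to participate in parallel training.
///
/// A sketch of an implementation (the `ConstantTrainer` type here is
/// hypothetical, used only for illustration):
///
/// ```ignore
/// struct ConstantTrainer;
///
/// impl ParallelTrainable<Vec<f64>, Vec<f64>> for ConstantTrainer {
///     type Output = f64;
///
///     fn train_single(&self, x: &Vec<f64>, _y: &Vec<f64>, _worker_id: usize) -> Result<f64> {
///         // "Train" by averaging the inputs.
///         Ok(x.iter().sum::<f64>() / x.len() as f64)
///     }
///
///     fn combine_results(results: Vec<f64>) -> Result<f64> {
///         let n = results.len() as f64;
///         Ok(results.into_iter().sum::<f64>() / n)
///     }
///
///     fn estimate_memory_usage(&self, data_size: usize) -> usize {
///         data_size * std::mem::size_of::<f64>()
///     }
/// }
/// ```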
pub trait ParallelTrainable<X, Y> {
    /// The result produced by training a single ensemble member.
    type Output;

    /// Train one ensemble member on the given data.
    fn train_single(&self, x: &X, y: &Y, worker_id: usize) -> Result<Self::Output>;

    /// Combine the per-worker results into a single output.
    fn combine_results(results: Vec<Self::Output>) -> Result<Self::Output>;

    /// Estimate the memory needed to train on `data_size` samples.
    fn estimate_memory_usage(&self, data_size: usize) -> usize;
}

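/// A contiguous range of estimator indices assigned to a single worker.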
#[derive(Debug, Clone)]
pub struct DataPartition {
    /// First estimator index in this partition (inclusive).
    pub start_idx: usize,
    /// One past the last estimator index in this partition (exclusive).
    pub end_idx: usize,
    /// Worker responsible for this partition.
    pub worker_id: usize,
    /// Estimated memory required for this partition, in megabytes.
    pub memory_estimate_mb: f64,
}

impl ParallelEnsembleTrainer {
    /// Create a trainer with the given configuration.
    pub fn new(config: ParallelConfig) -> Self {
        Self {
            config,
            performance_metrics: ParallelPerformanceMetrics::default(),
        }
    }

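    /// Create a trainer sized to the machine's available parallelism,
    /// falling back to 4 workers if it cannot be determined.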
    pub fn auto() -> Self {
        let n_workers = thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);

        Self::new(ParallelConfig {
            n_workers: Some(n_workers),
            ..Default::default()
        })
    }

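    /// Train `n_estimators` ensemble members using data parallelism: the
    /// estimators are partitioned across workers and each worker trains its
    /// share independently, returning one result per partition. Timing and
    /// efficiency figures are recorded in [`ParallelPerformanceMetrics`].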
    pub fn train_data_parallel<T, X, Y>(
        &mut self,
        trainer: &T,
        x: &X,
        y: &Y,
        n_estimators: usize,
    ) -> Result<Vec<T::Output>>
    where
        T: ParallelTrainable<X, Y> + Sync + Send,
        X: Clone + Send + Sync,
        Y: Clone + Send + Sync,
        T::Output: Send,
    {
        let start_time = Instant::now();

        let n_workers = self.config.n_workers.unwrap_or_else(|| {
            thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4)
        });

        self.performance_metrics.workers_used = n_workers;

        let partitions = self.create_data_partitions(n_estimators, n_workers)?;

        let parallel_start = Instant::now();

        #[cfg(feature = "parallel")]
        {
            // Prefer an explicitly configured pool size; otherwise size the
            // pool to the number of workers.
            let pool_size = self.config.thread_pool_size.unwrap_or(n_workers);
            let pool = ThreadPoolBuilder::new()
                .num_threads(pool_size)
                .build()
                .map_err(|e| {
                    SklearsError::InvalidInput(format!("Failed to create thread pool: {}", e))
                })?;

            let results: Result<Vec<_>> = pool.install(|| {
                partitions
                    .into_par_iter()
                    .map(|partition| trainer.train_single(x, y, partition.worker_id))
                    .collect()
            });

            let parallel_results = results?;

            self.performance_metrics.parallel_time_secs = parallel_start.elapsed().as_secs_f64();

            // Synchronization phase: results are already collected above, so
            // this currently measures only the (negligible) bookkeeping cost.
            let sync_start = Instant::now();
            self.performance_metrics.sync_time_secs = sync_start.elapsed().as_secs_f64();
            self.performance_metrics.total_time_secs = start_time.elapsed().as_secs_f64();

            self.calculate_efficiency_metrics();

            Ok(parallel_results)
        }

        #[cfg(not(feature = "parallel"))]
        {
            // Sequential fallback: train each partition in turn.
            let mut results = Vec::new();
            for partition in partitions {
                let result = trainer.train_single(x, y, partition.worker_id)?;
                results.push(result);
            }

            self.performance_metrics.parallel_time_secs = parallel_start.elapsed().as_secs_f64();
            self.performance_metrics.total_time_secs = start_time.elapsed().as_secs_f64();

            Ok(results)
        }
    }

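    /// Train several distinct models concurrently (model parallelism), one
    /// trainer per worker, returning outputs in trainer order.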
    pub fn train_model_parallel<T, X, Y>(
        &mut self,
        trainers: Vec<&T>,
        x: &X,
        y: &Y,
    ) -> Result<Vec<T::Output>>
    where
        T: ParallelTrainable<X, Y> + Sync + Send,
        X: Clone + Send + Sync,
        Y: Clone + Send + Sync,
        T::Output: Send,
    {
        let start_time = Instant::now();

        #[cfg(feature = "parallel")]
        {
            let results: Result<Vec<_>> = trainers
                .into_par_iter()
                .enumerate()
                .map(|(worker_id, trainer)| trainer.train_single(x, y, worker_id))
                .collect();

            self.performance_metrics.total_time_secs = start_time.elapsed().as_secs_f64();
            results
        }

        #[cfg(not(feature = "parallel"))]
        {
            let mut results = Vec::new();
            for (worker_id, trainer) in trainers.into_iter().enumerate() {
                let result = trainer.train_single(x, y, worker_id)?;
                results.push(result);
            }

            self.performance_metrics.total_time_secs = start_time.elapsed().as_secs_f64();
            Ok(results)
        }
    }

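    /// Split `n_estimators` across `n_workers` as evenly as possible; the
    /// first `n_estimators % n_workers` workers each take one extra
    /// estimator. For example, 10 estimators over 4 workers yields partition
    /// sizes 3, 3, 2, 2. Empty partitions are skipped.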
    fn create_data_partitions(
        &self,
        n_estimators: usize,
        n_workers: usize,
    ) -> Result<Vec<DataPartition>> {
        if n_estimators == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of estimators must be greater than 0".to_string(),
            ));
        }
        if n_workers == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of workers must be greater than 0".to_string(),
            ));
        }

        let mut partitions = Vec::new();
        let estimators_per_worker = n_estimators / n_workers;
        let remainder = n_estimators % n_workers;

        let mut start_idx = 0;

        for worker_id in 0..n_workers {
            // Spread the remainder over the first `remainder` workers.
            let current_size = estimators_per_worker + if worker_id < remainder { 1 } else { 0 };

            if current_size > 0 {
                let end_idx = start_idx + current_size;

                partitions.push(DataPartition {
                    start_idx,
                    end_idx,
                    worker_id,
                    memory_estimate_mb: self.estimate_partition_memory(current_size),
                });

                start_idx = end_idx;
            }
        }

        Ok(partitions)
    }

    fn estimate_partition_memory(&self, partition_size: usize) -> f64 {
        // Rough heuristic: assume a fixed per-estimator memory footprint.
        let base_memory_mb = 10.0;
        let size_factor = partition_size as f64 * base_memory_mb;

        if let Some(limit) = self.config.memory_limit_mb {
            size_factor.min(limit as f64)
        } else {
            size_factor
        }
    }

    fn calculate_efficiency_metrics(&mut self) {
        let ideal_time =
            self.performance_metrics.total_time_secs / self.performance_metrics.workers_used as f64;
        let actual_time = self.performance_metrics.parallel_time_secs;

        self.performance_metrics.parallel_efficiency = if actual_time > 0.0 {
            ideal_time / actual_time
        } else {
            0.0
        };

        // Simplified estimate: assume load-balancing overhead costs roughly
        // 10% of the achieved parallel efficiency.
        self.performance_metrics.load_balance_efficiency =
            self.performance_metrics.parallel_efficiency * 0.9;
    }

    /// Metrics collected during the most recent training run.
    pub fn performance_metrics(&self) -> &ParallelPerformanceMetrics {
        &self.performance_metrics
    }

    /// Reset all collected performance metrics to their defaults.
    pub fn reset_metrics(&mut self) {
        self.performance_metrics = ParallelPerformanceMetrics::default();
    }

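    /// Tune the configuration for the given hardware. Each worker gets an
    /// equal share of memory: `memory_limit_mb = memory_gb * 1024 / n_cores`.
    /// For example, 8 cores and 16 GB give a 2048 MB per-worker budget and a
    /// heuristic batch size of 20.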
    pub fn configure_for_hardware(&mut self, n_cores: usize, memory_gb: usize) {
        self.config.n_workers = Some(n_cores);
        self.config.thread_pool_size = Some(n_cores);
        self.config.memory_limit_mb = Some((memory_gb * 1024) / n_cores);

        // Heuristic: derive the batch size from the per-core memory budget.
        let estimated_batch_size = (memory_gb * 1024) / (n_cores * 100);
        self.config.batch_size = Some(estimated_batch_size);
    }

    /// Enable work stealing and load balancing, and switch to the hybrid
    /// parallelization strategy.
    pub fn enable_advanced_features(&mut self) {
        self.config.work_stealing = true;
        self.config.load_balancing = true;
        self.config.strategy = ParallelStrategy::Hybrid;
    }
}

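/// Coordinates asynchronous task submission and completion across workers.
/// The current implementation is a bookkeeping placeholder and does not yet
/// execute tasks asynchronously.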
pub struct AsyncEnsembleCoordinator {
    config: ParallelConfig,
    active_workers: Vec<usize>,
    completed_tasks: Vec<usize>,
}

impl AsyncEnsembleCoordinator {
    pub fn new(config: ParallelConfig) -> Self {
        Self {
            config,
            active_workers: Vec::new(),
            completed_tasks: Vec::new(),
        }
    }

    /// Record a task submission for a worker. Placeholder: only the worker
    /// assignment is tracked; the task id is currently unused.
    pub fn submit_task(&mut self, worker_id: usize, _task_id: usize) {
        self.active_workers.push(worker_id);
    }

    /// Wait for submitted tasks to finish. Placeholder: returns the tasks
    /// recorded as completed so far without blocking.
    pub fn wait_for_completion(&mut self) -> Result<Vec<usize>> {
        Ok(self.completed_tasks.clone())
    }

    /// Ids of workers that currently have submitted tasks.
    pub fn get_worker_status(&self) -> Vec<usize> {
        self.active_workers.clone()
    }
}

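/// Coordinates ensemble training across distributed (federated) nodes. The
/// current implementation is a placeholder: it tracks node addresses and an
/// aggregation strategy but performs no network communication.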
pub struct FederatedEnsembleCoordinator {
    nodes: Vec<String>,
    aggregation_strategy: String,
    communication_protocol: String,
}

impl FederatedEnsembleCoordinator {
    pub fn new(nodes: Vec<String>) -> Self {
        Self {
            nodes,
            aggregation_strategy: "average".to_string(),
            communication_protocol: "http".to_string(),
        }
    }

    /// Coordinate a round of federated training across the registered nodes.
    /// Placeholder: no communication is performed yet.
    pub fn coordinate_training(&self) -> Result<()> {
        Ok(())
    }

    pub fn set_aggregation_strategy(&mut self, strategy: &str) {
        self.aggregation_strategy = strategy.to_string();
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    struct MockTrainer;

    impl ParallelTrainable<Array2<f64>, Array1<i32>> for MockTrainer {
        type Output = Vec<f64>;

        fn train_single(
            &self,
            _x: &Array2<f64>,
            _y: &Array1<i32>,
            worker_id: usize,
        ) -> Result<Self::Output> {
            Ok(vec![worker_id as f64])
        }

        fn combine_results(results: Vec<Self::Output>) -> Result<Self::Output> {
            Ok(results.into_iter().flatten().collect())
        }

        fn estimate_memory_usage(&self, _data_size: usize) -> usize {
            1024
        }
    }

    #[test]
    fn test_parallel_trainer_creation() {
        let trainer = ParallelEnsembleTrainer::auto();
        assert!(trainer.config.n_workers.is_some());
        assert_eq!(trainer.config.strategy, ParallelStrategy::DataParallel);
    }

    #[test]
    fn test_data_partitions() {
        let config = ParallelConfig::default();
        let trainer = ParallelEnsembleTrainer::new(config);

        let partitions = trainer.create_data_partitions(10, 4).unwrap();
        assert_eq!(partitions.len(), 4);

        let total_estimators: usize = partitions.iter().map(|p| p.end_idx - p.start_idx).sum();
        assert_eq!(total_estimators, 10);
    }

    #[test]
    fn test_mock_parallel_training() {
        let mut trainer = ParallelEnsembleTrainer::auto();
        let mock_trainer = MockTrainer;

        let x = Array2::zeros((100, 5));
        let y = Array1::zeros(100);

        let results = trainer
            .train_data_parallel(&mock_trainer, &x, &y, 4)
            .unwrap();
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_hardware_configuration() {
        let mut trainer = ParallelEnsembleTrainer::auto();
        trainer.configure_for_hardware(8, 16);

        assert_eq!(trainer.config.n_workers, Some(8));
        assert_eq!(trainer.config.thread_pool_size, Some(8));
        assert!(trainer.config.memory_limit_mb.is_some());
    }

    #[test]
    fn test_async_coordinator() {
        let config = ParallelConfig::default();
        let mut coordinator = AsyncEnsembleCoordinator::new(config);

        coordinator.submit_task(0, 1);
        coordinator.submit_task(1, 2);

        assert_eq!(coordinator.get_worker_status().len(), 2);
    }

    #[test]
    fn test_federated_coordinator() {
        let nodes = vec!["node1".to_string(), "node2".to_string()];
        let mut coordinator = FederatedEnsembleCoordinator::new(nodes);

        coordinator.set_aggregation_strategy("weighted_average");
        assert_eq!(coordinator.aggregation_strategy, "weighted_average");
    }
}