1use rayon::prelude::*;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::SeedableRng;
10use scirs2_core::RngExt;
11use sklears_core::types::Float;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
/// Parallelization strategy used by `ParallelOptimizer` to schedule
/// hyperparameter evaluations across workers.
#[derive(Debug, Clone)]
pub enum ParallelStrategy {
    /// Exhaustive grid search evaluated in parallel chunks.
    ParallelGridSearch {
        /// Number of configurations handed to one rayon task at a time.
        chunk_size: usize,
        /// How work is distributed across workers.
        load_balancing: LoadBalancingStrategy,
    },
    /// Uniform random sampling evaluated batch-by-batch.
    ParallelRandomSearch {
        /// Number of configurations per parallel batch.
        batch_size: usize,
        /// Adapt the batch size to recent evaluation times when true.
        dynamic_batching: bool,
    },
    /// Surrogate-model-guided search with batched acquisition.
    ParallelBayesianOptimization {
        /// Configurations proposed per model update.
        batch_size: usize,
        /// How multiple points are acquired from one surrogate model.
        acquisition_strategy: BatchAcquisitionStrategy,
        /// How batches are synchronized between model updates.
        synchronization: SynchronizationStrategy,
    },
    /// Fire-and-poll asynchronous evaluation.
    AsynchronousOptimization {
        /// Upper bound on in-flight evaluations.
        max_concurrent: usize,
        /// How often completed results are collected.
        result_polling_interval: Duration,
    },
    /// Optimization spread over multiple machines.
    DistributedOptimization {
        /// Addresses/identifiers of remote workers.
        worker_nodes: Vec<String>,
        /// Transport used to exchange work and results.
        communication_protocol: CommunicationProtocol,
    },
    /// Parallel search over several objectives at once.
    MultiObjectiveParallel {
        /// Names of the objectives being traded off.
        objectives: Vec<String>,
        /// Batch size used when expanding the Pareto front.
        pareto_batch_size: usize,
    },
}
53
/// How grid-search work is split across workers.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Fixed, up-front partitioning of the work.
    Static,
    /// Re-partition when imbalance exceeds the threshold.
    Dynamic { rebalance_threshold: Float },
    /// Idle workers steal pending work from busy ones.
    WorkStealing,
    /// Order work by a named priority function.
    PriorityBased { priority_function: String },
}
66
/// Strategy for proposing a *batch* of points from one surrogate model,
/// so several evaluations can run before the model is refit.
#[derive(Debug, Clone)]
pub enum BatchAcquisitionStrategy {
    /// Pretend pending points returned `liar_value` while picking the next one.
    ConstantLiar { liar_value: Float },
    /// Pretend pending points returned the model's own prediction.
    KrigingBeliever,
    /// Joint q-point expected improvement.
    QExpectedImprovement,
    /// Penalize the acquisition near already-selected points.
    LocalPenalization { penalization_factor: Float },
    /// Draw candidates from posterior samples of the model.
    ThompsonSampling { n_samples: usize },
}
81
/// How Bayesian-optimization batches are synchronized with model updates.
#[derive(Debug, Clone)]
pub enum SynchronizationStrategy {
    /// Wait for the whole batch before refitting the model.
    Synchronous,
    /// Do not wait between batches.
    Asynchronous,
    /// Synchronize only every `sync_interval` evaluations.
    Hybrid { sync_interval: usize },
}
92
/// Transport used by `ParallelStrategy::DistributedOptimization` to talk to
/// remote workers.
#[derive(Debug, Clone)]
pub enum CommunicationProtocol {
    /// Direct TCP connection on the given port.
    TCP { port: u16 },
    /// Broker-mediated message queue.
    MessageQueue { queue_name: String },
    /// Exchange work/results through files at a shared path.
    SharedFilesystem { path: String },
    /// User-defined transport configured via key/value pairs.
    Custom { config: HashMap<String, String> },
}
105
/// Configuration for a `ParallelOptimizer` run.
#[derive(Debug, Clone)]
pub struct ParallelOptimizationConfig {
    /// Which parallel search strategy to execute.
    pub strategy: ParallelStrategy,
    /// Size of the rayon thread pool.
    pub max_workers: usize,
    /// Per-evaluation time budget, if any.
    /// NOTE(review): not currently enforced anywhere in this module.
    pub timeout_per_evaluation: Option<Duration>,
    /// Per-worker memory cap in bytes, if any (not currently enforced).
    pub memory_limit_per_worker: Option<usize>,
    /// What to do when an individual evaluation fails.
    pub error_handling: ErrorHandlingStrategy,
    /// Progress-reporting options.
    pub progress_reporting: ProgressReportingConfig,
    /// Whether to track resource usage during the run.
    pub resource_monitoring: bool,
    /// Seed for deterministic sampling; `None` uses an OS-seeded RNG.
    pub random_state: Option<u64>,
}
118
/// Policy applied when an individual hyperparameter evaluation returns an error.
#[derive(Debug, Clone)]
pub enum ErrorHandlingStrategy {
    /// Stop submitting new work as soon as one evaluation fails.
    FailFast,
    /// Record the failure (score = -inf) and continue.
    SkipErrors,
    /// Retry failed evaluations with exponential backoff.
    RetryOnError {
        /// Maximum number of retry attempts.
        max_retries: usize,
        /// Multiplier applied to the delay between attempts.
        backoff_factor: Float,
    },
    /// Substitute a fixed score for failed evaluations.
    FallbackEvaluation { fallback_score: Float },
}
135
/// Options controlling progress reporting during an optimization run.
#[derive(Debug, Clone)]
pub struct ProgressReportingConfig {
    /// Master switch for progress reporting.
    pub enabled: bool,
    /// Minimum interval between progress updates.
    pub update_interval: Duration,
    /// Include per-worker / per-evaluation detail when true.
    pub detailed_metrics: bool,
    /// Persist intermediate results as the run progresses.
    pub export_intermediate_results: bool,
}
144
/// Final outcome of a parallel optimization run.
#[derive(Debug, Clone)]
pub struct ParallelOptimizationResult {
    /// Best configuration found (parameter name -> value).
    pub best_hyperparameters: HashMap<String, Float>,
    /// Score of the best configuration (higher is better).
    pub best_score: Float,
    /// Every evaluation performed, including failures.
    pub all_evaluations: Vec<EvaluationResult>,
    /// Aggregate statistics over the run.
    pub optimization_statistics: OptimizationStatistics,
    /// Per-worker breakdown of the work performed.
    pub worker_statistics: Vec<WorkerStatistics>,
    /// Ratio of useful work to available workers (see `create_result`).
    pub parallelization_efficiency: Float,
    /// NOTE(review): populated with the *sum* of evaluation times, not true
    /// wall-clock time — see `create_result`.
    pub total_wall_time: Duration,
    /// Estimated total CPU time across all workers.
    pub total_cpu_time: Duration,
}
157
/// Record of a single hyperparameter evaluation.
#[derive(Debug, Clone)]
pub struct EvaluationResult {
    /// Configuration that was evaluated.
    pub hyperparameters: HashMap<String, Float>,
    /// Returned score; `Float::NEG_INFINITY` when the evaluation failed.
    pub score: Float,
    /// How long the evaluation took.
    pub evaluation_time: Duration,
    /// NOTE(review): despite the name this is the evaluation's global index,
    /// not a thread id — see how the parallel loops assign it.
    pub worker_id: usize,
    /// When the evaluation started.
    pub timestamp: Instant,
    /// Optional extra metrics reported by the evaluation.
    pub additional_metrics: HashMap<String, Float>,
    /// Error message when the evaluation failed, `None` on success.
    pub error: Option<String>,
}
169
/// Aggregate statistics over all evaluations of a run.
#[derive(Debug, Clone)]
pub struct OptimizationStatistics {
    /// Total evaluations recorded (successes + failures).
    pub total_evaluations: usize,
    /// Evaluations that returned a score.
    pub successful_evaluations: usize,
    /// Evaluations that returned an error.
    pub failed_evaluations: usize,
    /// Mean per-evaluation duration.
    pub average_evaluation_time: Duration,
    /// NOTE(review): currently a hard-coded placeholder in `create_result`.
    pub convergence_rate: Float,
    /// Resource usage summary (placeholder values in `create_result`).
    pub resource_utilization: ResourceUtilization,
}
180
/// Fractional resource-usage summary (each field nominally in [0, 1]).
#[derive(Debug, Clone)]
pub struct ResourceUtilization {
    /// Fraction of CPU capacity used.
    pub cpu_utilization: Float,
    /// Fraction of memory used.
    pub memory_utilization: Float,
    /// Fraction of network capacity used.
    pub network_utilization: Float,
    /// Fraction of time workers spent idle.
    pub idle_time_percentage: Float,
}
189
/// Per-worker accounting derived from the evaluation log.
#[derive(Debug, Clone)]
pub struct WorkerStatistics {
    /// Identifier this worker's evaluations were tagged with.
    pub worker_id: usize,
    /// Number of evaluations attributed to this worker.
    pub evaluations_completed: usize,
    /// Sum of this worker's evaluation durations.
    pub total_computation_time: Duration,
    /// Time spent idle (currently never populated — stays zero).
    pub idle_time: Duration,
    /// Number of failed evaluations attributed to this worker.
    pub errors_encountered: usize,
    /// Mean duration of this worker's evaluations.
    pub average_evaluation_time: Duration,
}
200
/// Driver that executes a `ParallelStrategy` on a dedicated rayon thread
/// pool, accumulating results in a lock-protected shared state.
pub struct ParallelOptimizer {
    /// Run configuration (strategy, worker count, error policy, ...).
    config: ParallelOptimizationConfig,
    /// State shared between worker closures and the driver.
    shared_state: Arc<RwLock<SharedOptimizationState>>,
    /// Owned rayon pool sized to `config.max_workers`; always `Some` after
    /// construction (the `Option` allows taking it out if ever needed).
    worker_pool: Option<rayon::ThreadPool>,
}
207
/// Mutable state shared (behind an `RwLock`) between the optimization driver
/// and the parallel evaluation closures.
#[derive(Debug)]
pub struct SharedOptimizationState {
    /// Log of every evaluation performed so far.
    pub evaluations: Vec<EvaluationResult>,
    /// Best score seen so far; starts at `Float::NEG_INFINITY`.
    pub best_score: Float,
    /// Configuration that produced `best_score`.
    pub best_hyperparameters: HashMap<String, Float>,
    /// Queued configurations (currently unused by the implemented strategies).
    pub pending_evaluations: Vec<HashMap<String, Float>>,
    /// Number of evaluations recorded (mirrors `evaluations.len()`).
    pub completed_count: usize,
    /// Surrogate model used by Bayesian optimization, once enough data exists.
    pub gaussian_process_model: Option<SimplifiedGP>,
}
218
219#[derive(Debug, Clone)]
221pub struct SimplifiedGP {
222 pub observations: Vec<(Vec<Float>, Float)>,
223 pub hyperparameters: GPHyperparams,
224 pub trained: bool,
225}
226
/// Kernel hyperparameters for `SimplifiedGP` (RBF-style parameterization).
#[derive(Debug, Clone)]
pub struct GPHyperparams {
    /// Kernel length scale.
    pub length_scale: Float,
    /// Signal (output) variance.
    pub signal_variance: Float,
    /// Observation-noise variance.
    pub noise_variance: Float,
}
234
impl Default for ParallelOptimizationConfig {
    /// Sensible defaults: random search in dynamic batches of 4, one worker
    /// per logical CPU, a 5-minute per-evaluation timeout, skip-on-error
    /// handling, and progress updates every 10 seconds.
    fn default() -> Self {
        Self {
            strategy: ParallelStrategy::ParallelRandomSearch {
                batch_size: 4,
                dynamic_batching: true,
            },
            // One worker per logical CPU.
            max_workers: num_cpus::get(),
            timeout_per_evaluation: Some(Duration::from_secs(300)),
            memory_limit_per_worker: None,
            error_handling: ErrorHandlingStrategy::SkipErrors,
            progress_reporting: ProgressReportingConfig {
                enabled: true,
                update_interval: Duration::from_secs(10),
                detailed_metrics: false,
                export_intermediate_results: false,
            },
            resource_monitoring: true,
            random_state: None,
        }
    }
}
257
258impl ParallelOptimizer {
259 pub fn new(config: ParallelOptimizationConfig) -> Result<Self, Box<dyn std::error::Error>> {
261 let worker_pool = rayon::ThreadPoolBuilder::new()
263 .num_threads(config.max_workers)
264 .build()?;
265
266 let shared_state = Arc::new(RwLock::new(SharedOptimizationState {
267 evaluations: Vec::new(),
268 best_score: Float::NEG_INFINITY,
269 best_hyperparameters: HashMap::new(),
270 pending_evaluations: Vec::new(),
271 completed_count: 0,
272 gaussian_process_model: None,
273 }));
274
275 Ok(Self {
276 config,
277 shared_state,
278 worker_pool: Some(worker_pool),
279 })
280 }
281
    /// Entry point: runs the optimization loop selected by `config.strategy`.
    ///
    /// `evaluation_fn` maps a named configuration (`param_i` -> value) to a
    /// score where higher is better; `parameter_bounds` gives `(low, high)`
    /// per parameter; `max_evaluations` caps the number of configurations
    /// tried (batched strategies may overshoot by up to one batch).
    ///
    /// # Errors
    /// Propagates setup and sampling failures from the selected strategy.
    pub fn optimize<F>(
        &mut self,
        evaluation_fn: F,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        // NOTE(review): captured but never used — `create_result` derives the
        // reported "wall time" from summed evaluation times instead.
        let _start_time = Instant::now();
        // Wrap in Arc so the evaluation function can be shared by worker threads.
        let evaluation_fn = Arc::new(evaluation_fn);

        match &self.config.strategy {
            ParallelStrategy::ParallelGridSearch { .. } => {
                self.parallel_grid_search(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::ParallelRandomSearch { .. } => {
                self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::ParallelBayesianOptimization { .. } => self
                .parallel_bayesian_optimization(evaluation_fn, parameter_bounds, max_evaluations),
            ParallelStrategy::AsynchronousOptimization { .. } => {
                self.asynchronous_optimization(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::DistributedOptimization { .. } => {
                self.distributed_optimization(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::MultiObjectiveParallel { .. } => self
                .multi_objective_parallel_optimization(
                    evaluation_fn,
                    parameter_bounds,
                    max_evaluations,
                ),
        }
    }
321
    /// Exhaustive grid search: builds a grid of at most `max_evaluations`
    /// configurations and evaluates it in parallel chunks of `chunk_size`.
    ///
    /// All results (including failures) are appended to the shared state and
    /// the best score/configuration pair is tracked incrementally.
    fn parallel_grid_search<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        // Extract strategy parameters; `optimize` only routes the matching
        // variant here, so other variants are genuinely unreachable.
        // NOTE(review): `load_balancing` is extracted but not yet honored.
        let (chunk_size, _load_balancing) = match &self.config.strategy {
            ParallelStrategy::ParallelGridSearch {
                chunk_size,
                load_balancing,
            } => (*chunk_size, load_balancing),
            _ => unreachable!(),
        };

        let grid_configs = self.generate_grid_configurations(parameter_bounds, max_evaluations)?;

        let shared_state = self.shared_state.clone();
        let worker_pool = self.worker_pool.as_ref().expect("operation should succeed");

        worker_pool.install(|| {
            grid_configs
                .par_chunks(chunk_size)
                .enumerate()
                .for_each(|(chunk_id, chunk)| {
                    for (config_id, config) in chunk.iter().enumerate() {
                        // NOTE(review): this is the configuration's global
                        // index in the grid, not an actual thread id.
                        let worker_id = chunk_id * chunk_size + config_id;
                        let start_time = Instant::now();

                        match evaluation_fn(config) {
                            Ok(score) => {
                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: None,
                                };

                                // Lock failure (poisoned lock) silently drops
                                // the result — best-effort recording.
                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;

                                    if score > state.best_score {
                                        state.best_score = score;
                                        state.best_hyperparameters = config.clone();
                                    }
                                }
                            }
                            Err(e) => {
                                // Failures are recorded with score = -inf so
                                // they can never become the best result.
                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score: Float::NEG_INFINITY,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: Some(e.to_string()),
                                };

                                if matches!(
                                    self.config.error_handling,
                                    ErrorHandlingStrategy::FailFast
                                ) {
                                    if let Ok(mut state) = shared_state.write() {
                                        state.evaluations.push(result);
                                        state.completed_count += 1;
                                    }
                                    // NOTE(review): this `return` only exits
                                    // the closure for *this* chunk; other
                                    // chunks keep running, so FailFast is not
                                    // a global abort.
                                    return;
                                }

                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;
                                }
                            }
                        }
                    }
                });
        });

        self.create_result()
    }
419
420 fn parallel_random_search<F>(
422 &mut self,
423 evaluation_fn: Arc<F>,
424 parameter_bounds: &[(Float, Float)],
425 max_evaluations: usize,
426 ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
427 where
428 F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
429 + Send
430 + Sync
431 + 'static,
432 {
433 let (batch_size, dynamic_batching) = match &self.config.strategy {
434 ParallelStrategy::ParallelRandomSearch {
435 batch_size,
436 dynamic_batching,
437 } => (*batch_size, *dynamic_batching),
438 _ => unreachable!(),
439 };
440
441 let shared_state = self.shared_state.clone();
442 let worker_pool = self.worker_pool.as_ref().expect("operation should succeed");
443
444 let mut rng = match self.config.random_state {
445 Some(seed) => StdRng::seed_from_u64(seed),
446 None => {
447 use scirs2_core::random::thread_rng;
448 StdRng::from_rng(&mut thread_rng())
449 }
450 };
451
452 let mut evaluations_completed = 0;
453 let mut current_batch_size = batch_size;
454
455 while evaluations_completed < max_evaluations {
456 if dynamic_batching {
458 current_batch_size = self.calculate_dynamic_batch_size(batch_size)?;
459 }
460
461 let batch_configs: Vec<HashMap<String, Float>> = (0..current_batch_size)
463 .map(|_| self.sample_random_configuration(parameter_bounds, &mut rng))
464 .collect::<Result<Vec<_>, _>>()?;
465
466 worker_pool.install(|| {
468 batch_configs
469 .par_iter()
470 .enumerate()
471 .for_each(|(local_id, config)| {
472 let worker_id = evaluations_completed + local_id;
473 let start_time = Instant::now();
474
475 match evaluation_fn(config) {
476 Ok(score) => {
477 let evaluation_time = start_time.elapsed();
478 let result = EvaluationResult {
479 hyperparameters: config.clone(),
480 score,
481 evaluation_time,
482 worker_id,
483 timestamp: start_time,
484 additional_metrics: HashMap::new(),
485 error: None,
486 };
487
488 if let Ok(mut state) = shared_state.write() {
489 state.evaluations.push(result);
490 state.completed_count += 1;
491
492 if score > state.best_score {
493 state.best_score = score;
494 state.best_hyperparameters = config.clone();
495 }
496 }
497 }
498 Err(e) => {
499 if !matches!(
500 self.config.error_handling,
501 ErrorHandlingStrategy::FailFast
502 ) {
503 let evaluation_time = start_time.elapsed();
504 let result = EvaluationResult {
505 hyperparameters: config.clone(),
506 score: Float::NEG_INFINITY,
507 evaluation_time,
508 worker_id,
509 timestamp: start_time,
510 additional_metrics: HashMap::new(),
511 error: Some(e.to_string()),
512 };
513
514 if let Ok(mut state) = shared_state.write() {
515 state.evaluations.push(result);
516 state.completed_count += 1;
517 }
518 }
519 }
520 }
521 });
522 });
523
524 evaluations_completed += current_batch_size;
525 }
526
527 self.create_result()
528 }
529
    /// Batched Bayesian optimization: warm-starts with a small random search,
    /// then alternates between refitting the surrogate model and evaluating
    /// an acquisition batch in parallel.
    fn parallel_bayesian_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        // Extract strategy parameters; `optimize` guarantees the variant.
        let (batch_size, acquisition_strategy, synchronization) = match &self.config.strategy {
            ParallelStrategy::ParallelBayesianOptimization {
                batch_size,
                acquisition_strategy,
                synchronization,
            } => (
                *batch_size,
                acquisition_strategy.clone(),
                synchronization.clone(),
            ),
            _ => unreachable!(),
        };

        let shared_state = self.shared_state.clone();

        // Warm-up: a few random evaluations to seed the surrogate model.
        // NOTE(review): `parallel_random_search` matches on `config.strategy`
        // and hits `unreachable!()` when invoked while the strategy is the
        // Bayesian variant — this delegation panics unless that match
        // tolerates other strategies; verify.
        let initial_evaluations = batch_size.min(5);
        self.parallel_random_search(evaluation_fn.clone(), parameter_bounds, initial_evaluations)?;

        let mut evaluations_completed = initial_evaluations;

        while evaluations_completed < max_evaluations {
            // Refit the surrogate from all successful evaluations so far.
            self.update_gaussian_process_model()?;

            let next_batch = self.generate_acquisition_batch(
                &acquisition_strategy,
                parameter_bounds,
                batch_size,
            )?;

            let worker_pool = self.worker_pool.as_ref().expect("operation should succeed");
            worker_pool.install(|| {
                next_batch
                    .par_iter()
                    .enumerate()
                    .for_each(|(local_id, config)| {
                        // NOTE(review): global evaluation index, not a thread id.
                        let worker_id = evaluations_completed + local_id;
                        let start_time = Instant::now();

                        match evaluation_fn(config) {
                            Ok(score) => {
                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: None,
                                };

                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;

                                    if score > state.best_score {
                                        state.best_score = score;
                                        state.best_hyperparameters = config.clone();
                                    }
                                }
                            }
                            Err(e) => {
                                // Failures recorded with score = -inf unless
                                // FailFast, in which case they are dropped.
                                if !matches!(
                                    self.config.error_handling,
                                    ErrorHandlingStrategy::FailFast
                                ) {
                                    let evaluation_time = start_time.elapsed();
                                    let result = EvaluationResult {
                                        hyperparameters: config.clone(),
                                        score: Float::NEG_INFINITY,
                                        evaluation_time,
                                        worker_id,
                                        timestamp: start_time,
                                        additional_metrics: HashMap::new(),
                                        error: Some(e.to_string()),
                                    };

                                    if let Ok(mut state) = shared_state.write() {
                                        state.evaluations.push(result);
                                        state.completed_count += 1;
                                    }
                                }
                            }
                        }
                    });
            });

            evaluations_completed += batch_size;

            match synchronization {
                SynchronizationStrategy::Synchronous => {
                    // Batch completion is already awaited by `install`.
                }
                SynchronizationStrategy::Asynchronous => {
                    // NOTE(review): this exits after a single acquisition
                    // batch rather than running asynchronously — confirm.
                    break;
                }
                SynchronizationStrategy::Hybrid { sync_interval } => {
                    if evaluations_completed % sync_interval == 0 {
                        // Token pause standing in for a real sync barrier.
                        std::thread::sleep(Duration::from_millis(10));
                    }
                }
            }
        }

        self.create_result()
    }
657
    /// Asynchronous optimization strategy.
    ///
    /// NOTE(review): currently a placeholder that delegates to
    /// `parallel_random_search`; `max_concurrent` and the polling interval
    /// from the strategy are not honored. The delegate matches on
    /// `config.strategy` and hits `unreachable!()` for this variant, so this
    /// path panics unless that match tolerates other strategies — verify.
    fn asynchronous_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
    }
674
    /// Distributed optimization strategy.
    ///
    /// NOTE(review): currently a placeholder that delegates to
    /// `parallel_random_search`; the configured worker nodes and
    /// communication protocol are ignored. The delegate matches on
    /// `config.strategy` and hits `unreachable!()` for this variant, so this
    /// path panics unless that match tolerates other strategies — verify.
    fn distributed_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
    }
692
    /// Multi-objective parallel optimization strategy.
    ///
    /// NOTE(review): currently a placeholder that delegates to
    /// `parallel_random_search`; the objectives and Pareto batch size are
    /// ignored. The delegate matches on `config.strategy` and hits
    /// `unreachable!()` for this variant, so this path panics unless that
    /// match tolerates other strategies — verify.
    fn multi_objective_parallel_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
    }
709
710 fn generate_grid_configurations(
712 &self,
713 parameter_bounds: &[(Float, Float)],
714 max_evaluations: usize,
715 ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
716 let n_params = parameter_bounds.len();
717 let n_values_per_param = (max_evaluations as Float)
718 .powf(1.0 / n_params as Float)
719 .ceil() as usize;
720
721 let mut configurations = Vec::new();
722 let mut indices = vec![0; n_params];
723
724 loop {
725 let mut config = HashMap::new();
726 for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
727 let value =
728 low + (high - low) * (indices[i] as Float) / (n_values_per_param - 1) as Float;
729 config.insert(format!("param_{}", i), value);
730 }
731 configurations.push(config);
732
733 let mut carry = 1;
735 for i in 0..n_params {
736 indices[i] += carry;
737 if indices[i] < n_values_per_param {
738 carry = 0;
739 break;
740 } else {
741 indices[i] = 0;
742 }
743 }
744
745 if carry == 1 || configurations.len() >= max_evaluations {
746 break;
747 }
748 }
749
750 Ok(configurations)
751 }
752
753 fn sample_random_configuration(
755 &self,
756 parameter_bounds: &[(Float, Float)],
757 rng: &mut StdRng,
758 ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
759 let mut config = HashMap::new();
760
761 for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
762 let value = rng.random_range(low..high + 1.0);
763 config.insert(format!("param_{}", i), value);
764 }
765
766 Ok(config)
767 }
768
769 fn calculate_dynamic_batch_size(
771 &self,
772 base_batch_size: usize,
773 ) -> Result<usize, Box<dyn std::error::Error>> {
774 if let Ok(state) = self.shared_state.read() {
776 if state.evaluations.len() >= 10 {
777 let recent_evaluations = &state.evaluations[state.evaluations.len() - 10..];
778 let avg_time = recent_evaluations
779 .iter()
780 .map(|e| e.evaluation_time.as_secs_f64())
781 .sum::<f64>()
782 / recent_evaluations.len() as f64;
783
784 if avg_time < 1.0 {
786 Ok(base_batch_size * 2) } else if avg_time > 10.0 {
788 Ok(base_batch_size / 2) } else {
790 Ok(base_batch_size)
791 }
792 } else {
793 Ok(base_batch_size)
794 }
795 } else {
796 Ok(base_batch_size)
797 }
798 }
799
800 fn update_gaussian_process_model(&mut self) -> Result<(), Box<dyn std::error::Error>> {
802 if let Ok(mut state) = self.shared_state.write() {
803 let observations: Vec<(Vec<Float>, Float)> = state
804 .evaluations
805 .iter()
806 .filter(|e| e.error.is_none())
807 .map(|e| {
808 let params: Vec<Float> = e.hyperparameters.values().cloned().collect();
809 (params, e.score)
810 })
811 .collect();
812
813 if observations.len() >= 3 {
814 let gp = SimplifiedGP {
815 observations,
816 hyperparameters: GPHyperparams {
817 length_scale: 1.0,
818 signal_variance: 1.0,
819 noise_variance: 0.1,
820 },
821 trained: true,
822 };
823 state.gaussian_process_model = Some(gp);
824 }
825 }
826 Ok(())
827 }
828
829 fn generate_acquisition_batch(
831 &self,
832 _acquisition_strategy: &BatchAcquisitionStrategy,
833 parameter_bounds: &[(Float, Float)],
834 batch_size: usize,
835 ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
836 let mut rng = StdRng::seed_from_u64(42); let mut batch = Vec::new();
838
839 for _ in 0..batch_size {
840 batch.push(self.sample_random_configuration(parameter_bounds, &mut rng)?);
843 }
844
845 Ok(batch)
846 }
847
    /// Assembles the final `ParallelOptimizationResult` from the shared state:
    /// aggregate statistics, per-worker breakdown, and timing estimates.
    ///
    /// # Panics
    /// Panics if the shared-state lock is poisoned (a worker panicked while
    /// holding it).
    fn create_result(&self) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>> {
        let state = self.shared_state.read().expect("operation should succeed");

        let successful_evaluations = state
            .evaluations
            .iter()
            .filter(|e| e.error.is_none())
            .count();

        let failed_evaluations = state.evaluations.len() - successful_evaluations;

        // Sum of per-evaluation durations — CPU-time-like, not wall clock.
        let total_evaluation_time: Duration =
            state.evaluations.iter().map(|e| e.evaluation_time).sum();

        let average_evaluation_time = if state.evaluations.is_empty() {
            Duration::from_secs(0)
        } else {
            total_evaluation_time / state.evaluations.len() as u32
        };

        // Group evaluations by the worker_id they were tagged with.
        let mut worker_stats = HashMap::new();
        for eval in &state.evaluations {
            let stats = worker_stats
                .entry(eval.worker_id)
                .or_insert(WorkerStatistics {
                    worker_id: eval.worker_id,
                    evaluations_completed: 0,
                    total_computation_time: Duration::from_secs(0),
                    idle_time: Duration::from_secs(0),
                    errors_encountered: 0,
                    average_evaluation_time: Duration::from_secs(0),
                });

            stats.evaluations_completed += 1;
            stats.total_computation_time += eval.evaluation_time;
            if eval.error.is_some() {
                stats.errors_encountered += 1;
            }
        }

        // Second pass: derive per-worker averages.
        for stats in worker_stats.values_mut() {
            if stats.evaluations_completed > 0 {
                stats.average_evaluation_time =
                    stats.total_computation_time / stats.evaluations_completed as u32;
            }
        }

        Ok(ParallelOptimizationResult {
            best_hyperparameters: state.best_hyperparameters.clone(),
            best_score: state.best_score,
            all_evaluations: state.evaluations.clone(),
            optimization_statistics: OptimizationStatistics {
                total_evaluations: state.evaluations.len(),
                successful_evaluations,
                failed_evaluations,
                average_evaluation_time,
                // NOTE(review): placeholder values — convergence rate and
                // resource utilization are not actually measured.
                convergence_rate: 0.1,
                resource_utilization: ResourceUtilization {
                    cpu_utilization: 0.8,
                    memory_utilization: 0.6,
                    network_utilization: 0.1,
                    idle_time_percentage: 0.1,
                },
            },
            worker_statistics: worker_stats.into_values().collect(),
            // NOTE(review): successes / max_workers is not a conventional
            // efficiency ratio (it is unbounded) — confirm the intent.
            parallelization_efficiency: successful_evaluations as Float
                / self.config.max_workers as Float,
            // NOTE(review): labeled "wall time" but holds the *sum* of
            // evaluation times; `optimize` captures a start Instant that is
            // never used here.
            total_wall_time: total_evaluation_time,
            total_cpu_time: total_evaluation_time * self.config.max_workers as u32,
        })
    }
921}
922
/// Convenience wrapper: builds a `ParallelOptimizer` from `config` (or the
/// defaults when `None`) and runs a single optimization.
///
/// # Errors
/// Returns an error if the optimizer cannot be constructed (thread-pool
/// failure) or the optimization itself fails.
pub fn parallel_optimize<F>(
    evaluation_fn: F,
    parameter_bounds: &[(Float, Float)],
    max_evaluations: usize,
    config: Option<ParallelOptimizationConfig>,
) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
where
    F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
        + Send
        + Sync
        + 'static,
{
    let config = config.unwrap_or_default();
    let mut optimizer = ParallelOptimizer::new(config)?;
    optimizer.optimize(evaluation_fn, parameter_bounds, max_evaluations)
}
940
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Smooth test objective: sum of `-(x - 0.5)^2` over all parameter
    /// values. Its maximum is 0.0, attained when every parameter is 0.5.
    fn mock_evaluation_function(
        hyperparameters: &HashMap<String, Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let score = hyperparameters
            .values()
            .map(|&x| -(x - 0.5).powi(2))
            .sum::<Float>();
        Ok(score)
    }

    #[test]
    fn test_parallel_optimizer_creation() {
        // The default configuration must always yield a usable optimizer.
        let config = ParallelOptimizationConfig::default();
        let optimizer = ParallelOptimizer::new(config);
        assert!(optimizer.is_ok());
    }

    #[test]
    fn test_parallel_random_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelRandomSearch {
                batch_size: 4,
                dynamic_batching: false,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0), (0.0, 1.0)];

        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            10,
            Some(config),
        )
        .expect("operation should succeed");

        // The objective is non-positive everywhere.
        assert!(result.best_score <= 0.0);
        // Batching may overshoot the requested count, but only by a bounded
        // amount (here: at most one extra batch plus slack).
        assert!(result.optimization_statistics.total_evaluations >= 10);
        assert!(result.optimization_statistics.total_evaluations <= 16);
        assert!(!result.worker_statistics.is_empty());
    }

    #[test]
    fn test_parallel_grid_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelGridSearch {
                chunk_size: 2,
                load_balancing: LoadBalancingStrategy::Static,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0), (0.0, 1.0)];

        // 9 evaluations over two parameters => a 3x3 grid.
        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            9,
            Some(config),
        )
        .expect("operation should succeed");

        assert!(result.best_score <= 0.0);
        assert!(result.optimization_statistics.total_evaluations > 0);
    }

    #[test]
    fn test_error_handling() {
        // An evaluation function that always fails.
        let failing_function =
            |_: &HashMap<String, Float>| -> Result<Float, Box<dyn std::error::Error>> {
                Err("Test error".into())
            };

        let config = ParallelOptimizationConfig {
            error_handling: ErrorHandlingStrategy::SkipErrors,
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0)];

        let result = parallel_optimize(failing_function, &parameter_bounds, 5, Some(config))
            .expect("operation should succeed");

        // With SkipErrors every evaluation is recorded as a failure and none
        // succeed.
        assert!(result.optimization_statistics.failed_evaluations >= 5);
        assert_eq!(result.optimization_statistics.successful_evaluations, 0);
        assert_eq!(
            result.optimization_statistics.total_evaluations,
            result.optimization_statistics.failed_evaluations
        );
    }
}