use rayon::prelude::*;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use sklears_core::types::Float;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

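/// Strategy used to parallelize hyperparameter search.
///
/// Each variant carries the tuning knobs specific to that execution model.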
#[derive(Debug, Clone)]
pub enum ParallelStrategy {
    /// Exhaustive grid search evaluated in parallel chunks.
    ParallelGridSearch {
        /// Number of configurations handed to a worker at a time.
        chunk_size: usize,
        /// How work is distributed across workers.
        load_balancing: LoadBalancingStrategy,
    },
    /// Random search evaluated in parallel batches.
    ParallelRandomSearch {
        /// Number of configurations evaluated per batch.
        batch_size: usize,
        /// Adapt the batch size to observed evaluation times.
        dynamic_batching: bool,
    },
    /// Batch Bayesian optimization driven by a shared surrogate model.
    ParallelBayesianOptimization {
        batch_size: usize,
        acquisition_strategy: BatchAcquisitionStrategy,
        synchronization: SynchronizationStrategy,
    },
    /// Fully asynchronous evaluation with a bounded number of in-flight jobs.
    AsynchronousOptimization {
        max_concurrent: usize,
        result_polling_interval: Duration,
    },
    /// Search distributed across multiple worker nodes.
    DistributedOptimization {
        worker_nodes: Vec<String>,
        communication_protocol: CommunicationProtocol,
    },
    /// Parallel multi-objective search maintaining a Pareto front.
    MultiObjectiveParallel {
        objectives: Vec<String>,
        pareto_batch_size: usize,
    },
}

/// How grid-search work is assigned to workers.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Fixed, even split of the work up front.
    Static,
    /// Rebalance whenever measured imbalance exceeds the threshold.
    Dynamic { rebalance_threshold: Float },
    /// Idle workers steal queued work from busy ones.
    WorkStealing,
    /// Order work by a named priority function.
    PriorityBased { priority_function: String },
}

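/// How a batch of candidate points is proposed from one surrogate model
/// before any of the pending evaluations has returned.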
#[derive(Debug, Clone)]
pub enum BatchAcquisitionStrategy {
    /// Pretend each pending point returned a fixed "lie" value.
    ConstantLiar { liar_value: Float },
    /// Use the surrogate's posterior mean for pending points.
    KrigingBeliever,
    /// Joint q-point expected improvement.
    QExpectedImprovement,
    /// Penalize the acquisition surface around pending points.
    LocalPenalization { penalization_factor: Float },
    /// Draw candidates by sampling from the posterior.
    ThompsonSampling { n_samples: usize },
}

/// How batches are synchronized during Bayesian optimization.
#[derive(Debug, Clone)]
pub enum SynchronizationStrategy {
    /// Wait for a whole batch before proposing the next one.
    Synchronous,
    /// Propose new points as soon as results arrive.
    Asynchronous,
    /// Asynchronous, with a full synchronization every `sync_interval` evaluations.
    Hybrid { sync_interval: usize },
}

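/// Transport used to exchange work and results between distributed workers.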
#[derive(Debug, Clone)]
pub enum CommunicationProtocol {
    /// Direct TCP connections on the given port.
    TCP { port: u16 },
    /// A named message queue.
    MessageQueue { queue_name: String },
    /// Exchange results through files under a shared path.
    SharedFilesystem { path: String },
    /// User-supplied protocol settings.
    Custom { config: HashMap<String, String> },
}

/// Top-level configuration for a parallel optimization run.
#[derive(Debug, Clone)]
pub struct ParallelOptimizationConfig {
    pub strategy: ParallelStrategy,
    pub max_workers: usize,
    pub timeout_per_evaluation: Option<Duration>,
    pub memory_limit_per_worker: Option<usize>,
    pub error_handling: ErrorHandlingStrategy,
    pub progress_reporting: ProgressReportingConfig,
    pub resource_monitoring: bool,
    pub random_state: Option<u64>,
}

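/// How failed evaluations are handled during the run.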
#[derive(Debug, Clone)]
pub enum ErrorHandlingStrategy {
    /// Abort the whole run on the first error.
    FailFast,
    /// Record the error and continue with the remaining configurations.
    SkipErrors,
    /// Retry failed evaluations, backing off between attempts.
    RetryOnError {
        max_retries: usize,
        backoff_factor: Float,
    },
    /// Substitute a fixed fallback score for failed evaluations.
    FallbackEvaluation { fallback_score: Float },
}

#[derive(Debug, Clone)]
pub struct ProgressReportingConfig {
    pub enabled: bool,
    pub update_interval: Duration,
    pub detailed_metrics: bool,
    pub export_intermediate_results: bool,
}

/// Aggregate outcome of a parallel optimization run.
#[derive(Debug, Clone)]
pub struct ParallelOptimizationResult {
    pub best_hyperparameters: HashMap<String, Float>,
    pub best_score: Float,
    pub all_evaluations: Vec<EvaluationResult>,
    pub optimization_statistics: OptimizationStatistics,
    pub worker_statistics: Vec<WorkerStatistics>,
    pub parallelization_efficiency: Float,
    pub total_wall_time: Duration,
    pub total_cpu_time: Duration,
}

#[derive(Debug, Clone)]
pub struct EvaluationResult {
    pub hyperparameters: HashMap<String, Float>,
    pub score: Float,
    pub evaluation_time: Duration,
    pub worker_id: usize,
    pub timestamp: Instant,
    pub additional_metrics: HashMap<String, Float>,
    pub error: Option<String>,
}

#[derive(Debug, Clone)]
pub struct OptimizationStatistics {
    pub total_evaluations: usize,
    pub successful_evaluations: usize,
    pub failed_evaluations: usize,
    pub average_evaluation_time: Duration,
    pub convergence_rate: Float,
    pub resource_utilization: ResourceUtilization,
}

/// Utilization figures expressed as fractions in `[0, 1]`.
#[derive(Debug, Clone)]
pub struct ResourceUtilization {
    pub cpu_utilization: Float,
    pub memory_utilization: Float,
    pub network_utilization: Float,
    pub idle_time_percentage: Float,
}

#[derive(Debug, Clone)]
pub struct WorkerStatistics {
    pub worker_id: usize,
    pub evaluations_completed: usize,
    pub total_computation_time: Duration,
    pub idle_time: Duration,
    pub errors_encountered: usize,
    pub average_evaluation_time: Duration,
}

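/// Parallel hyperparameter optimizer backed by a dedicated rayon thread pool.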
pub struct ParallelOptimizer {
    config: ParallelOptimizationConfig,
    shared_state: Arc<RwLock<SharedOptimizationState>>,
    worker_pool: Option<rayon::ThreadPool>,
}

/// Mutable state shared between workers behind an `RwLock`.
#[derive(Debug)]
pub struct SharedOptimizationState {
    pub evaluations: Vec<EvaluationResult>,
    pub best_score: Float,
    pub best_hyperparameters: HashMap<String, Float>,
    pub pending_evaluations: Vec<HashMap<String, Float>>,
    pub completed_count: usize,
    pub gaussian_process_model: Option<SimplifiedGP>,
}

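/// Deliberately simplified Gaussian-process surrogate: it stores observations
/// and fixed kernel hyperparameters rather than fitting a full GP.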
#[derive(Debug, Clone)]
pub struct SimplifiedGP {
    pub observations: Vec<(Vec<Float>, Float)>,
    pub hyperparameters: GPHyperparams,
    pub trained: bool,
}

#[derive(Debug, Clone)]
pub struct GPHyperparams {
    pub length_scale: Float,
    pub signal_variance: Float,
    pub noise_variance: Float,
}

impl Default for ParallelOptimizationConfig {
    fn default() -> Self {
        Self {
            strategy: ParallelStrategy::ParallelRandomSearch {
                batch_size: 4,
                dynamic_batching: true,
            },
            max_workers: num_cpus::get(),
            timeout_per_evaluation: Some(Duration::from_secs(300)),
            memory_limit_per_worker: None,
            error_handling: ErrorHandlingStrategy::SkipErrors,
            progress_reporting: ProgressReportingConfig {
                enabled: true,
                update_interval: Duration::from_secs(10),
                detailed_metrics: false,
                export_intermediate_results: false,
            },
            resource_monitoring: true,
            random_state: None,
        }
    }
}

impl ParallelOptimizer {
    /// Build an optimizer together with its dedicated worker thread pool.
    pub fn new(config: ParallelOptimizationConfig) -> Result<Self, Box<dyn std::error::Error>> {
        let worker_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(config.max_workers)
            .build()?;

        let shared_state = Arc::new(RwLock::new(SharedOptimizationState {
            evaluations: Vec::new(),
            best_score: Float::NEG_INFINITY,
            best_hyperparameters: HashMap::new(),
            pending_evaluations: Vec::new(),
            completed_count: 0,
            gaussian_process_model: None,
        }));

        Ok(Self {
            config,
            shared_state,
            worker_pool: Some(worker_pool),
        })
    }

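    /// Dispatch to the configured strategy and run until roughly
    /// `max_evaluations` configurations have been scored (batched strategies
    /// may overshoot the budget by up to one batch).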
    pub fn optimize<F>(
        &mut self,
        evaluation_fn: F,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        let evaluation_fn = Arc::new(evaluation_fn);

        match &self.config.strategy {
            ParallelStrategy::ParallelGridSearch { .. } => {
                self.parallel_grid_search(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::ParallelRandomSearch { .. } => {
                self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::ParallelBayesianOptimization { .. } => self
                .parallel_bayesian_optimization(evaluation_fn, parameter_bounds, max_evaluations),
            ParallelStrategy::AsynchronousOptimization { .. } => {
                self.asynchronous_optimization(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::DistributedOptimization { .. } => {
                self.distributed_optimization(evaluation_fn, parameter_bounds, max_evaluations)
            }
            ParallelStrategy::MultiObjectiveParallel { .. } => self
                .multi_objective_parallel_optimization(
                    evaluation_fn,
                    parameter_bounds,
                    max_evaluations,
                ),
        }
    }

    fn parallel_grid_search<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        let (chunk_size, _load_balancing) = match &self.config.strategy {
            ParallelStrategy::ParallelGridSearch {
                chunk_size,
                load_balancing,
            } => (*chunk_size, load_balancing),
            _ => unreachable!(),
        };

        let grid_configs = self.generate_grid_configurations(parameter_bounds, max_evaluations)?;

        let shared_state = self.shared_state.clone();
        let worker_pool = self.worker_pool.as_ref().unwrap();

        worker_pool.install(|| {
            grid_configs
                .par_chunks(chunk_size)
                .enumerate()
                .for_each(|(chunk_id, chunk)| {
                    for (config_id, config) in chunk.iter().enumerate() {
                        // Global index of this evaluation, used as a worker label.
                        let worker_id = chunk_id * chunk_size + config_id;
                        let start_time = Instant::now();

                        match evaluation_fn(config) {
                            Ok(score) => {
                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: None,
                                };

                                // Record the result and update the incumbent best.
                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;

                                    if score > state.best_score {
                                        state.best_score = score;
                                        state.best_hyperparameters = config.clone();
                                    }
                                }
                            }
                            Err(e) => {
                                // FailFast aborts the run by panicking inside the pool.
                                if matches!(
                                    self.config.error_handling,
                                    ErrorHandlingStrategy::FailFast
                                ) {
                                    panic!("Evaluation failed: {}", e);
                                }

                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score: Float::NEG_INFINITY,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: Some(e.to_string()),
                                };

                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;
                                }
                            }
                        }
                    }
                });
        });

        self.create_result()
    }

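    /// Random search in parallel batches; with `dynamic_batching` enabled the
    /// batch size adapts to recently observed evaluation times.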
    fn parallel_random_search<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        let (batch_size, dynamic_batching) = match &self.config.strategy {
            ParallelStrategy::ParallelRandomSearch {
                batch_size,
                dynamic_batching,
            } => (*batch_size, *dynamic_batching),
            // Strategies without a dedicated implementation delegate to this
            // method, so pick a sensible batch size for them instead of
            // treating the arm as unreachable.
            ParallelStrategy::AsynchronousOptimization { max_concurrent, .. } => {
                (*max_concurrent, false)
            }
            _ => (self.config.max_workers.max(1), false),
        };

        let shared_state = self.shared_state.clone();
        let worker_pool = self.worker_pool.as_ref().unwrap();

        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };

        let mut evaluations_completed = 0;
        let mut current_batch_size = batch_size;

        while evaluations_completed < max_evaluations {
            if dynamic_batching {
                current_batch_size = self.calculate_dynamic_batch_size(batch_size)?;
            }

            // Sample the whole batch up front so workers stay data-parallel.
            let batch_configs: Vec<HashMap<String, Float>> = (0..current_batch_size)
                .map(|_| self.sample_random_configuration(parameter_bounds, &mut rng))
                .collect::<Result<Vec<_>, _>>()?;

            worker_pool.install(|| {
                batch_configs
                    .par_iter()
                    .enumerate()
                    .for_each(|(local_id, config)| {
                        let worker_id = evaluations_completed + local_id;
                        let start_time = Instant::now();

                        match evaluation_fn(config) {
                            Ok(score) => {
                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: None,
                                };

                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;

                                    if score > state.best_score {
                                        state.best_score = score;
                                        state.best_hyperparameters = config.clone();
                                    }
                                }
                            }
                            Err(e) => {
                                // FailFast aborts the run, mirroring the grid-search path.
                                if matches!(
                                    self.config.error_handling,
                                    ErrorHandlingStrategy::FailFast
                                ) {
                                    panic!("Evaluation failed: {}", e);
                                }

                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score: Float::NEG_INFINITY,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: Some(e.to_string()),
                                };

                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;
                                }
                            }
                        }
                    });
            });

            evaluations_completed += current_batch_size;
        }

        self.create_result()
    }

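    /// Batch Bayesian optimization: seed the surrogate with a few random
    /// evaluations, then alternate between refitting the model and evaluating
    /// an acquisition batch in parallel.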
    fn parallel_bayesian_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        let (batch_size, acquisition_strategy, synchronization) = match &self.config.strategy {
            ParallelStrategy::ParallelBayesianOptimization {
                batch_size,
                acquisition_strategy,
                synchronization,
            } => (
                *batch_size,
                acquisition_strategy.clone(),
                synchronization.clone(),
            ),
            _ => unreachable!(),
        };

        let shared_state = self.shared_state.clone();

        // Seed the surrogate model with a few random evaluations.
        let initial_evaluations = batch_size.min(5);
        self.parallel_random_search(evaluation_fn.clone(), parameter_bounds, initial_evaluations)?;

        let mut evaluations_completed = initial_evaluations;

        while evaluations_completed < max_evaluations {
            self.update_gaussian_process_model()?;

            let next_batch = self.generate_acquisition_batch(
                &acquisition_strategy,
                parameter_bounds,
                batch_size,
            )?;

            let worker_pool = self.worker_pool.as_ref().unwrap();
            worker_pool.install(|| {
                next_batch
                    .par_iter()
                    .enumerate()
                    .for_each(|(local_id, config)| {
                        let worker_id = evaluations_completed + local_id;
                        let start_time = Instant::now();

                        match evaluation_fn(config) {
                            Ok(score) => {
                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: None,
                                };

                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;

                                    if score > state.best_score {
                                        state.best_score = score;
                                        state.best_hyperparameters = config.clone();
                                    }
                                }
                            }
                            Err(e) => {
                                // FailFast aborts the run, mirroring the grid-search path.
                                if matches!(
                                    self.config.error_handling,
                                    ErrorHandlingStrategy::FailFast
                                ) {
                                    panic!("Evaluation failed: {}", e);
                                }

                                let evaluation_time = start_time.elapsed();
                                let result = EvaluationResult {
                                    hyperparameters: config.clone(),
                                    score: Float::NEG_INFINITY,
                                    evaluation_time,
                                    worker_id,
                                    timestamp: start_time,
                                    additional_metrics: HashMap::new(),
                                    error: Some(e.to_string()),
                                };

                                if let Ok(mut state) = shared_state.write() {
                                    state.evaluations.push(result);
                                    state.completed_count += 1;
                                }
                            }
                        }
                    });
            });

            evaluations_completed += batch_size;

            match synchronization {
                SynchronizationStrategy::Synchronous => {
                    // `install` already waits for the whole batch; nothing to do.
                }
                SynchronizationStrategy::Asynchronous => {
                    // Simplified asynchronous behavior: stop after one batch.
                    break;
                }
                SynchronizationStrategy::Hybrid { sync_interval } => {
                    if sync_interval > 0 && evaluations_completed % sync_interval == 0 {
                        // Brief pause at each synchronization point.
                        std::thread::sleep(Duration::from_millis(10));
                    }
                }
            }
        }

        self.create_result()
    }

    /// Asynchronous optimization is not yet implemented separately; it
    /// currently delegates to the parallel random search path.
    fn asynchronous_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
    }

    /// Distributed optimization placeholder: currently evaluates locally via
    /// the parallel random search path.
    fn distributed_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
    }

    /// Multi-objective placeholder: currently optimizes the single scalar
    /// objective via the parallel random search path.
    fn multi_objective_parallel_optimization<F>(
        &mut self,
        evaluation_fn: Arc<F>,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
            + Send
            + Sync
            + 'static,
    {
        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
    }

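    /// Build a full factorial grid with `ceil(max_evaluations^(1/n_params))`
    /// (at least two) evenly spaced values per parameter, truncated to the
    /// evaluation budget.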
    fn generate_grid_configurations(
        &self,
        parameter_bounds: &[(Float, Float)],
        max_evaluations: usize,
    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
        let n_params = parameter_bounds.len();
        // At least two values per axis; this also avoids a division by zero
        // in the interpolation below when the budget allows only one value.
        let n_values_per_param = ((max_evaluations as Float)
            .powf(1.0 / n_params as Float)
            .ceil() as usize)
            .max(2);

        let mut configurations = Vec::new();
        let mut indices = vec![0; n_params];

        loop {
            let mut config = HashMap::new();
            for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
                let value =
                    low + (high - low) * (indices[i] as Float) / (n_values_per_param - 1) as Float;
                config.insert(format!("param_{}", i), value);
            }
            configurations.push(config);

            // Advance `indices` like an odometer, carrying into the next
            // parameter whenever one wraps around.
            let mut carry = 1;
            for i in 0..n_params {
                indices[i] += carry;
                if indices[i] < n_values_per_param {
                    carry = 0;
                    break;
                } else {
                    indices[i] = 0;
                }
            }

            // Stop once the full grid is enumerated or the budget is reached.
            if carry == 1 || configurations.len() >= max_evaluations {
                break;
            }
        }

        Ok(configurations)
    }

    fn sample_random_configuration(
        &self,
        parameter_bounds: &[(Float, Float)],
        rng: &mut StdRng,
    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
        let mut config = HashMap::new();

        for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
            // Sample uniformly from the inclusive range [low, high].
            let value = rng.gen_range(low..=high);
            config.insert(format!("param_{}", i), value);
        }

        Ok(config)
    }

    /// Adapt the batch size to the average of the last ten evaluation times:
    /// grow for cheap evaluations, shrink (but never to zero) for costly ones.
    fn calculate_dynamic_batch_size(
        &self,
        base_batch_size: usize,
    ) -> Result<usize, Box<dyn std::error::Error>> {
        if let Ok(state) = self.shared_state.read() {
            if state.evaluations.len() >= 10 {
                let recent_evaluations = &state.evaluations[state.evaluations.len() - 10..];
                let avg_time = recent_evaluations
                    .iter()
                    .map(|e| e.evaluation_time.as_secs_f64())
                    .sum::<f64>()
                    / recent_evaluations.len() as f64;

                if avg_time < 1.0 {
                    // Fast evaluations: double the batch to reduce overhead.
                    Ok(base_batch_size * 2)
                } else if avg_time > 10.0 {
                    // Slow evaluations: halve the batch, keeping at least one.
                    Ok((base_batch_size / 2).max(1))
                } else {
                    Ok(base_batch_size)
                }
            } else {
                Ok(base_batch_size)
            }
        } else {
            Ok(base_batch_size)
        }
    }

    /// Rebuild the simplified surrogate from all successful evaluations.
    fn update_gaussian_process_model(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        if let Ok(mut state) = self.shared_state.write() {
            let observations: Vec<(Vec<Float>, Float)> = state
                .evaluations
                .iter()
                .filter(|e| e.error.is_none())
                .map(|e| {
                    let params: Vec<Float> = e.hyperparameters.values().cloned().collect();
                    (params, e.score)
                })
                .collect();

            // A surrogate with fewer than three points is not informative.
            if observations.len() >= 3 {
                let gp = SimplifiedGP {
                    observations,
                    hyperparameters: GPHyperparams {
                        length_scale: 1.0,
                        signal_variance: 1.0,
                        noise_variance: 0.1,
                    },
                    trained: true,
                };
                state.gaussian_process_model = Some(gp);
            }
        }
        Ok(())
    }

    /// Propose the next batch of configurations. The acquisition strategy is
    /// not yet used: candidates are currently drawn at random with a fixed
    /// seed so that batch proposals are reproducible.
    fn generate_acquisition_batch(
        &self,
        _acquisition_strategy: &BatchAcquisitionStrategy,
        parameter_bounds: &[(Float, Float)],
        batch_size: usize,
    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
        let mut rng = StdRng::seed_from_u64(42);
        let mut batch = Vec::new();

        for _ in 0..batch_size {
            batch.push(self.sample_random_configuration(parameter_bounds, &mut rng)?);
        }

        Ok(batch)
    }

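    /// Assemble the final result and per-worker statistics from shared state.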
    fn create_result(&self) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>> {
        let state = self.shared_state.read().unwrap();

        let successful_evaluations = state
            .evaluations
            .iter()
            .filter(|e| e.error.is_none())
            .count();

        let failed_evaluations = state.evaluations.len() - successful_evaluations;

        let total_evaluation_time: Duration =
            state.evaluations.iter().map(|e| e.evaluation_time).sum();

        let average_evaluation_time = if state.evaluations.is_empty() {
            Duration::from_secs(0)
        } else {
            total_evaluation_time / state.evaluations.len() as u32
        };

        // Aggregate per-worker statistics from the evaluation log.
        let mut worker_stats = HashMap::new();
        for eval in &state.evaluations {
            let stats = worker_stats
                .entry(eval.worker_id)
                .or_insert(WorkerStatistics {
                    worker_id: eval.worker_id,
                    evaluations_completed: 0,
                    total_computation_time: Duration::from_secs(0),
                    idle_time: Duration::from_secs(0),
                    errors_encountered: 0,
                    average_evaluation_time: Duration::from_secs(0),
                });

            stats.evaluations_completed += 1;
            stats.total_computation_time += eval.evaluation_time;
            if eval.error.is_some() {
                stats.errors_encountered += 1;
            }
        }

        for stats in worker_stats.values_mut() {
            if stats.evaluations_completed > 0 {
                stats.average_evaluation_time =
                    stats.total_computation_time / stats.evaluations_completed as u32;
            }
        }

        Ok(ParallelOptimizationResult {
            best_hyperparameters: state.best_hyperparameters.clone(),
            best_score: state.best_score,
            all_evaluations: state.evaluations.clone(),
            optimization_statistics: OptimizationStatistics {
                total_evaluations: state.evaluations.len(),
                successful_evaluations,
                failed_evaluations,
                average_evaluation_time,
                // Placeholder: convergence is not yet estimated from the data.
                convergence_rate: 0.1,
                // Placeholder utilization figures pending real monitoring.
                resource_utilization: ResourceUtilization {
                    cpu_utilization: 0.8,
                    memory_utilization: 0.6,
                    network_utilization: 0.1,
                    idle_time_percentage: 0.1,
                },
            },
            worker_statistics: worker_stats.into_values().collect(),
            // Crude proxy: successful evaluations per configured worker.
            parallelization_efficiency: successful_evaluations as Float
                / self.config.max_workers as Float,
            // The summed evaluation time is the aggregate CPU time; wall time
            // is approximated assuming ideal parallel speedup.
            total_wall_time: total_evaluation_time / self.config.max_workers.max(1) as u32,
            total_cpu_time: total_evaluation_time,
        })
    }
}

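/// Convenience entry point: builds a [`ParallelOptimizer`] from `config`
/// (or the default configuration) and runs it.
///
/// A minimal usage sketch, assuming a two-parameter objective with its
/// optimum at 0.5 in each dimension; illustrative only, so it is not
/// compiled as a doctest:
///
/// ```ignore
/// let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
/// let result = parallel_optimize(
///     |params| Ok(-params.values().map(|&x| (x - 0.5).powi(2)).sum::<Float>()),
///     &bounds,
///     20,   // evaluation budget
///     None, // use ParallelOptimizationConfig::default()
/// )?;
/// println!("best score: {}", result.best_score);
/// ```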
pub fn parallel_optimize<F>(
    evaluation_fn: F,
    parameter_bounds: &[(Float, Float)],
    max_evaluations: usize,
    config: Option<ParallelOptimizationConfig>,
) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
where
    F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
        + Send
        + Sync
        + 'static,
{
    let config = config.unwrap_or_default();
    let mut optimizer = ParallelOptimizer::new(config)?;
    optimizer.optimize(evaluation_fn, parameter_bounds, max_evaluations)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Quadratic bowl with its maximum (score 0) at 0.5 in every dimension.
    fn mock_evaluation_function(
        hyperparameters: &HashMap<String, Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let score = hyperparameters
            .values()
            .map(|&x| -(x - 0.5).powi(2))
            .sum::<Float>();
        Ok(score)
    }

    #[test]
    fn test_parallel_optimizer_creation() {
        let config = ParallelOptimizationConfig::default();
        let optimizer = ParallelOptimizer::new(config);
        assert!(optimizer.is_ok());
    }

    #[test]
    fn test_parallel_random_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelRandomSearch {
                batch_size: 4,
                dynamic_batching: false,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0), (0.0, 1.0)];

        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            10,
            Some(config),
        )
        .unwrap();

        // The mock objective is non-positive everywhere, peaking at 0.
        assert!(result.best_score <= 0.0);
        // Whole batches run at a time, so the count may overshoot the budget.
        assert!(result.optimization_statistics.total_evaluations >= 10);
        assert!(result.optimization_statistics.total_evaluations <= 16);
        assert!(!result.worker_statistics.is_empty());
    }

    #[test]
    fn test_parallel_grid_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelGridSearch {
                chunk_size: 2,
                load_balancing: LoadBalancingStrategy::Static,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0), (0.0, 1.0)];

        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            9, // 3 values per parameter -> a full 3x3 grid
            Some(config),
        )
        .unwrap();

        assert!(result.best_score <= 0.0);
        assert!(result.optimization_statistics.total_evaluations > 0);
    }

    #[test]
    fn test_error_handling() {
        let failing_function =
            |_: &HashMap<String, Float>| -> Result<Float, Box<dyn std::error::Error>> {
                Err("Test error".into())
            };

        let config = ParallelOptimizationConfig {
            error_handling: ErrorHandlingStrategy::SkipErrors,
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0)];

        let result =
            parallel_optimize(failing_function, &parameter_bounds, 5, Some(config)).unwrap();

        // Every evaluation fails, and all failures are recorded.
        assert!(result.optimization_statistics.failed_evaluations >= 5);
        assert_eq!(result.optimization_statistics.successful_evaluations, 0);
        assert_eq!(
            result.optimization_statistics.total_evaluations,
            result.optimization_statistics.failed_evaluations
        );
    }
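
    // Added sanity check (not part of the original suite): configurations
    // sampled by `sample_random_configuration` must respect the inclusive
    // parameter bounds.
    #[test]
    fn test_random_configuration_within_bounds() {
        let optimizer = ParallelOptimizer::new(ParallelOptimizationConfig::default()).unwrap();
        let bounds = vec![(0.0, 1.0), (-2.0, 2.0)];
        let mut rng = StdRng::seed_from_u64(0);

        for _ in 0..100 {
            let config = optimizer
                .sample_random_configuration(&bounds, &mut rng)
                .unwrap();
            for (i, &(low, high)) in bounds.iter().enumerate() {
                let value = config[&format!("param_{}", i)];
                assert!(value >= low && value <= high);
            }
        }
    }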
}