Skip to main content

sklears_model_selection/
parallel_optimization.rs

1//! Parallel Hyperparameter Search
2//!
3//! This module provides parallel hyperparameter optimization using rayon for concurrent evaluation
4//! of hyperparameter configurations. It supports various parallelization strategies and load balancing
5//! techniques to efficiently utilize available computational resources.
6
7use rayon::prelude::*;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::SeedableRng;
10use scirs2_core::RngExt;
11use sklears_core::types::Float;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
16/// Parallel optimization strategies
17#[derive(Debug, Clone)]
18pub enum ParallelStrategy {
19    /// Simple parallel grid search
20    ParallelGridSearch {
21        chunk_size: usize,
22
23        load_balancing: LoadBalancingStrategy,
24    },
25    /// Parallel random search
26    ParallelRandomSearch {
27        batch_size: usize,
28
29        dynamic_batching: bool,
30    },
31    /// Parallel Bayesian optimization
32    ParallelBayesianOptimization {
33        batch_size: usize,
34        acquisition_strategy: BatchAcquisitionStrategy,
35        synchronization: SynchronizationStrategy,
36    },
37    /// Asynchronous optimization
38    AsynchronousOptimization {
39        max_concurrent: usize,
40        result_polling_interval: Duration,
41    },
42    /// Distributed optimization across multiple machines
43    DistributedOptimization {
44        worker_nodes: Vec<String>,
45        communication_protocol: CommunicationProtocol,
46    },
47    /// Multi-objective parallel optimization
48    MultiObjectiveParallel {
49        objectives: Vec<String>,
50        pareto_batch_size: usize,
51    },
52}
53
54/// Load balancing strategies for parallel execution
55#[derive(Debug, Clone)]
56pub enum LoadBalancingStrategy {
57    /// Static load balancing with equal chunks
58    Static,
59    /// Dynamic load balancing based on completion time
60    Dynamic { rebalance_threshold: Float },
61    /// Work-stealing approach
62    WorkStealing,
63    /// Priority-based load balancing
64    PriorityBased { priority_function: String },
65}
66
67/// Batch acquisition strategies for parallel Bayesian optimization
68#[derive(Debug, Clone)]
69pub enum BatchAcquisitionStrategy {
70    /// Constant Liar strategy
71    ConstantLiar { liar_value: Float },
72    /// Kriging Believer strategy
73    KrigingBeliever,
74    /// qExpected Improvement
75    QExpectedImprovement,
76    /// Local Penalization
77    LocalPenalization { penalization_factor: Float },
78    /// Thompson Sampling
79    ThompsonSampling { n_samples: usize },
80}
81
/// When a parallel optimizer folds finished evaluations back into its model.
#[derive(Clone, Debug)]
pub enum SynchronizationStrategy {
    /// Wait for every worker in the round before updating.
    Synchronous,
    /// Update as soon as individual results arrive.
    Asynchronous,
    /// Mostly asynchronous, with a synchronization point every `sync_interval` evaluations.
    Hybrid {
        /// Number of evaluations between synchronization points.
        sync_interval: usize,
    },
}
92
/// Transport used to coordinate distributed optimization workers.
#[derive(Clone, Debug)]
pub enum CommunicationProtocol {
    /// Plain TCP on the given port.
    TCP {
        /// Port to connect to / listen on.
        port: u16,
    },
    /// A named message queue.
    MessageQueue {
        /// Name of the queue.
        queue_name: String,
    },
    /// Exchange through a shared filesystem location.
    SharedFilesystem {
        /// Shared directory or file path.
        path: String,
    },
    /// User-defined protocol described by free-form key/value settings.
    Custom {
        /// Protocol-specific settings.
        config: HashMap<String, String>,
    },
}
105
106/// Parallel optimization configuration
107#[derive(Debug, Clone)]
108pub struct ParallelOptimizationConfig {
109    pub strategy: ParallelStrategy,
110    pub max_workers: usize,
111    pub timeout_per_evaluation: Option<Duration>,
112    pub memory_limit_per_worker: Option<usize>,
113    pub error_handling: ErrorHandlingStrategy,
114    pub progress_reporting: ProgressReportingConfig,
115    pub resource_monitoring: bool,
116    pub random_state: Option<u64>,
117}
118
119/// Error handling strategies
120#[derive(Debug, Clone)]
121pub enum ErrorHandlingStrategy {
122    /// Fail fast on any error
123    FailFast,
124    /// Continue on errors, skip failed evaluations
125    SkipErrors,
126    /// Retry failed evaluations
127    RetryOnError {
128        max_retries: usize,
129
130        backoff_factor: Float,
131    },
132    /// Use fallback evaluations for failed ones
133    FallbackEvaluation { fallback_score: Float },
134}
135
/// Controls if, how often, and how verbosely search progress is reported.
#[derive(Clone, Debug)]
pub struct ProgressReportingConfig {
    /// Master switch for progress reporting.
    pub enabled: bool,
    /// Time between progress updates.
    pub update_interval: Duration,
    /// Whether to include detailed metrics in updates.
    pub detailed_metrics: bool,
    /// Whether intermediate results should be exported.
    pub export_intermediate_results: bool,
}
144
145/// Parallel optimization result
146#[derive(Debug, Clone)]
147pub struct ParallelOptimizationResult {
148    pub best_hyperparameters: HashMap<String, Float>,
149    pub best_score: Float,
150    pub all_evaluations: Vec<EvaluationResult>,
151    pub optimization_statistics: OptimizationStatistics,
152    pub worker_statistics: Vec<WorkerStatistics>,
153    pub parallelization_efficiency: Float,
154    pub total_wall_time: Duration,
155    pub total_cpu_time: Duration,
156}
157
158/// Individual evaluation result
159#[derive(Debug, Clone)]
160pub struct EvaluationResult {
161    pub hyperparameters: HashMap<String, Float>,
162    pub score: Float,
163    pub evaluation_time: Duration,
164    pub worker_id: usize,
165    pub timestamp: Instant,
166    pub additional_metrics: HashMap<String, Float>,
167    pub error: Option<String>,
168}
169
170/// Optimization statistics
171#[derive(Debug, Clone)]
172pub struct OptimizationStatistics {
173    pub total_evaluations: usize,
174    pub successful_evaluations: usize,
175    pub failed_evaluations: usize,
176    pub average_evaluation_time: Duration,
177    pub convergence_rate: Float,
178    pub resource_utilization: ResourceUtilization,
179}
180
181/// Resource utilization metrics
182#[derive(Debug, Clone)]
183pub struct ResourceUtilization {
184    pub cpu_utilization: Float,
185    pub memory_utilization: Float,
186    pub network_utilization: Float,
187    pub idle_time_percentage: Float,
188}
189
/// Counters describing what a single worker did during a run.
#[derive(Clone, Debug)]
pub struct WorkerStatistics {
    /// Worker identifier.
    pub worker_id: usize,
    /// Number of evaluations this worker finished.
    pub evaluations_completed: usize,
    /// Time spent evaluating.
    pub total_computation_time: Duration,
    /// Time spent waiting for work.
    pub idle_time: Duration,
    /// Number of failed evaluations on this worker.
    pub errors_encountered: usize,
    /// Mean evaluation duration on this worker.
    pub average_evaluation_time: Duration,
}
200
201/// Parallel hyperparameter optimizer
202pub struct ParallelOptimizer {
203    config: ParallelOptimizationConfig,
204    shared_state: Arc<RwLock<SharedOptimizationState>>,
205    worker_pool: Option<rayon::ThreadPool>,
206}
207
208/// Shared state between workers
209#[derive(Debug)]
210pub struct SharedOptimizationState {
211    pub evaluations: Vec<EvaluationResult>,
212    pub best_score: Float,
213    pub best_hyperparameters: HashMap<String, Float>,
214    pub pending_evaluations: Vec<HashMap<String, Float>>,
215    pub completed_count: usize,
216    pub gaussian_process_model: Option<SimplifiedGP>,
217}
218
219/// Simplified Gaussian Process for parallel optimization
220#[derive(Debug, Clone)]
221pub struct SimplifiedGP {
222    pub observations: Vec<(Vec<Float>, Float)>,
223    pub hyperparameters: GPHyperparams,
224    pub trained: bool,
225}
226
227/// GP hyperparameters
228#[derive(Debug, Clone)]
229pub struct GPHyperparams {
230    pub length_scale: Float,
231    pub signal_variance: Float,
232    pub noise_variance: Float,
233}
234
235impl Default for ParallelOptimizationConfig {
236    fn default() -> Self {
237        Self {
238            strategy: ParallelStrategy::ParallelRandomSearch {
239                batch_size: 4,
240                dynamic_batching: true,
241            },
242            max_workers: num_cpus::get(),
243            timeout_per_evaluation: Some(Duration::from_secs(300)),
244            memory_limit_per_worker: None,
245            error_handling: ErrorHandlingStrategy::SkipErrors,
246            progress_reporting: ProgressReportingConfig {
247                enabled: true,
248                update_interval: Duration::from_secs(10),
249                detailed_metrics: false,
250                export_intermediate_results: false,
251            },
252            resource_monitoring: true,
253            random_state: None,
254        }
255    }
256}
257
258impl ParallelOptimizer {
259    /// Create a new parallel optimizer
260    pub fn new(config: ParallelOptimizationConfig) -> Result<Self, Box<dyn std::error::Error>> {
261        // Create custom thread pool
262        let worker_pool = rayon::ThreadPoolBuilder::new()
263            .num_threads(config.max_workers)
264            .build()?;
265
266        let shared_state = Arc::new(RwLock::new(SharedOptimizationState {
267            evaluations: Vec::new(),
268            best_score: Float::NEG_INFINITY,
269            best_hyperparameters: HashMap::new(),
270            pending_evaluations: Vec::new(),
271            completed_count: 0,
272            gaussian_process_model: None,
273        }));
274
275        Ok(Self {
276            config,
277            shared_state,
278            worker_pool: Some(worker_pool),
279        })
280    }
281
282    /// Optimize hyperparameters in parallel
283    pub fn optimize<F>(
284        &mut self,
285        evaluation_fn: F,
286        parameter_bounds: &[(Float, Float)],
287        max_evaluations: usize,
288    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
289    where
290        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
291            + Send
292            + Sync
293            + 'static,
294    {
295        let _start_time = Instant::now();
296        let evaluation_fn = Arc::new(evaluation_fn);
297
298        match &self.config.strategy {
299            ParallelStrategy::ParallelGridSearch { .. } => {
300                self.parallel_grid_search(evaluation_fn, parameter_bounds, max_evaluations)
301            }
302            ParallelStrategy::ParallelRandomSearch { .. } => {
303                self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
304            }
305            ParallelStrategy::ParallelBayesianOptimization { .. } => self
306                .parallel_bayesian_optimization(evaluation_fn, parameter_bounds, max_evaluations),
307            ParallelStrategy::AsynchronousOptimization { .. } => {
308                self.asynchronous_optimization(evaluation_fn, parameter_bounds, max_evaluations)
309            }
310            ParallelStrategy::DistributedOptimization { .. } => {
311                self.distributed_optimization(evaluation_fn, parameter_bounds, max_evaluations)
312            }
313            ParallelStrategy::MultiObjectiveParallel { .. } => self
314                .multi_objective_parallel_optimization(
315                    evaluation_fn,
316                    parameter_bounds,
317                    max_evaluations,
318                ),
319        }
320    }
321
322    /// Parallel grid search implementation
323    fn parallel_grid_search<F>(
324        &mut self,
325        evaluation_fn: Arc<F>,
326        parameter_bounds: &[(Float, Float)],
327        max_evaluations: usize,
328    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
329    where
330        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
331            + Send
332            + Sync
333            + 'static,
334    {
335        let (chunk_size, _load_balancing) = match &self.config.strategy {
336            ParallelStrategy::ParallelGridSearch {
337                chunk_size,
338                load_balancing,
339            } => (*chunk_size, load_balancing),
340            _ => unreachable!(),
341        };
342
343        // Generate grid configurations
344        let grid_configs = self.generate_grid_configurations(parameter_bounds, max_evaluations)?;
345
346        // Process configurations in parallel chunks
347        let shared_state = self.shared_state.clone();
348        let worker_pool = self.worker_pool.as_ref().expect("operation should succeed");
349
350        worker_pool.install(|| {
351            grid_configs
352                .par_chunks(chunk_size)
353                .enumerate()
354                .for_each(|(chunk_id, chunk)| {
355                    for (config_id, config) in chunk.iter().enumerate() {
356                        let worker_id = chunk_id * chunk_size + config_id;
357                        let start_time = Instant::now();
358
359                        match evaluation_fn(config) {
360                            Ok(score) => {
361                                let evaluation_time = start_time.elapsed();
362                                let result = EvaluationResult {
363                                    hyperparameters: config.clone(),
364                                    score,
365                                    evaluation_time,
366                                    worker_id,
367                                    timestamp: start_time,
368                                    additional_metrics: HashMap::new(),
369                                    error: None,
370                                };
371
372                                // Update shared state
373                                if let Ok(mut state) = shared_state.write() {
374                                    state.evaluations.push(result);
375                                    state.completed_count += 1;
376
377                                    if score > state.best_score {
378                                        state.best_score = score;
379                                        state.best_hyperparameters = config.clone();
380                                    }
381                                }
382                            }
383                            Err(e) => {
384                                let evaluation_time = start_time.elapsed();
385                                let result = EvaluationResult {
386                                    hyperparameters: config.clone(),
387                                    score: Float::NEG_INFINITY,
388                                    evaluation_time,
389                                    worker_id,
390                                    timestamp: start_time,
391                                    additional_metrics: HashMap::new(),
392                                    error: Some(e.to_string()),
393                                };
394
395                                if matches!(
396                                    self.config.error_handling,
397                                    ErrorHandlingStrategy::FailFast
398                                ) {
399                                    // Record the error and return early instead of panicking
400                                    if let Ok(mut state) = shared_state.write() {
401                                        state.evaluations.push(result);
402                                        state.completed_count += 1;
403                                    }
404                                    return;
405                                }
406
407                                if let Ok(mut state) = shared_state.write() {
408                                    state.evaluations.push(result);
409                                    state.completed_count += 1;
410                                }
411                            }
412                        }
413                    }
414                });
415        });
416
417        self.create_result()
418    }
419
420    /// Parallel random search implementation
421    fn parallel_random_search<F>(
422        &mut self,
423        evaluation_fn: Arc<F>,
424        parameter_bounds: &[(Float, Float)],
425        max_evaluations: usize,
426    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
427    where
428        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
429            + Send
430            + Sync
431            + 'static,
432    {
433        let (batch_size, dynamic_batching) = match &self.config.strategy {
434            ParallelStrategy::ParallelRandomSearch {
435                batch_size,
436                dynamic_batching,
437            } => (*batch_size, *dynamic_batching),
438            _ => unreachable!(),
439        };
440
441        let shared_state = self.shared_state.clone();
442        let worker_pool = self.worker_pool.as_ref().expect("operation should succeed");
443
444        let mut rng = match self.config.random_state {
445            Some(seed) => StdRng::seed_from_u64(seed),
446            None => {
447                use scirs2_core::random::thread_rng;
448                StdRng::from_rng(&mut thread_rng())
449            }
450        };
451
452        let mut evaluations_completed = 0;
453        let mut current_batch_size = batch_size;
454
455        while evaluations_completed < max_evaluations {
456            // Adjust batch size dynamically if enabled
457            if dynamic_batching {
458                current_batch_size = self.calculate_dynamic_batch_size(batch_size)?;
459            }
460
461            // Generate batch of random configurations
462            let batch_configs: Vec<HashMap<String, Float>> = (0..current_batch_size)
463                .map(|_| self.sample_random_configuration(parameter_bounds, &mut rng))
464                .collect::<Result<Vec<_>, _>>()?;
465
466            // Evaluate batch in parallel
467            worker_pool.install(|| {
468                batch_configs
469                    .par_iter()
470                    .enumerate()
471                    .for_each(|(local_id, config)| {
472                        let worker_id = evaluations_completed + local_id;
473                        let start_time = Instant::now();
474
475                        match evaluation_fn(config) {
476                            Ok(score) => {
477                                let evaluation_time = start_time.elapsed();
478                                let result = EvaluationResult {
479                                    hyperparameters: config.clone(),
480                                    score,
481                                    evaluation_time,
482                                    worker_id,
483                                    timestamp: start_time,
484                                    additional_metrics: HashMap::new(),
485                                    error: None,
486                                };
487
488                                if let Ok(mut state) = shared_state.write() {
489                                    state.evaluations.push(result);
490                                    state.completed_count += 1;
491
492                                    if score > state.best_score {
493                                        state.best_score = score;
494                                        state.best_hyperparameters = config.clone();
495                                    }
496                                }
497                            }
498                            Err(e) => {
499                                if !matches!(
500                                    self.config.error_handling,
501                                    ErrorHandlingStrategy::FailFast
502                                ) {
503                                    let evaluation_time = start_time.elapsed();
504                                    let result = EvaluationResult {
505                                        hyperparameters: config.clone(),
506                                        score: Float::NEG_INFINITY,
507                                        evaluation_time,
508                                        worker_id,
509                                        timestamp: start_time,
510                                        additional_metrics: HashMap::new(),
511                                        error: Some(e.to_string()),
512                                    };
513
514                                    if let Ok(mut state) = shared_state.write() {
515                                        state.evaluations.push(result);
516                                        state.completed_count += 1;
517                                    }
518                                }
519                            }
520                        }
521                    });
522            });
523
524            evaluations_completed += current_batch_size;
525        }
526
527        self.create_result()
528    }
529
530    /// Parallel Bayesian optimization implementation
531    fn parallel_bayesian_optimization<F>(
532        &mut self,
533        evaluation_fn: Arc<F>,
534        parameter_bounds: &[(Float, Float)],
535        max_evaluations: usize,
536    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
537    where
538        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
539            + Send
540            + Sync
541            + 'static,
542    {
543        let (batch_size, acquisition_strategy, synchronization) = match &self.config.strategy {
544            ParallelStrategy::ParallelBayesianOptimization {
545                batch_size,
546                acquisition_strategy,
547                synchronization,
548            } => (
549                *batch_size,
550                acquisition_strategy.clone(),
551                synchronization.clone(),
552            ),
553            _ => unreachable!(),
554        };
555
556        let shared_state = self.shared_state.clone();
557
558        // Initialize with random evaluations
559        let initial_evaluations = batch_size.min(5);
560        self.parallel_random_search(evaluation_fn.clone(), parameter_bounds, initial_evaluations)?;
561
562        let mut evaluations_completed = initial_evaluations;
563
564        while evaluations_completed < max_evaluations {
565            // Update Gaussian Process model
566            self.update_gaussian_process_model()?;
567
568            // Generate next batch using acquisition strategy
569            let next_batch = self.generate_acquisition_batch(
570                &acquisition_strategy,
571                parameter_bounds,
572                batch_size,
573            )?;
574
575            // Evaluate batch in parallel
576            let worker_pool = self.worker_pool.as_ref().expect("operation should succeed");
577            worker_pool.install(|| {
578                next_batch
579                    .par_iter()
580                    .enumerate()
581                    .for_each(|(local_id, config)| {
582                        let worker_id = evaluations_completed + local_id;
583                        let start_time = Instant::now();
584
585                        match evaluation_fn(config) {
586                            Ok(score) => {
587                                let evaluation_time = start_time.elapsed();
588                                let result = EvaluationResult {
589                                    hyperparameters: config.clone(),
590                                    score,
591                                    evaluation_time,
592                                    worker_id,
593                                    timestamp: start_time,
594                                    additional_metrics: HashMap::new(),
595                                    error: None,
596                                };
597
598                                if let Ok(mut state) = shared_state.write() {
599                                    state.evaluations.push(result);
600                                    state.completed_count += 1;
601
602                                    if score > state.best_score {
603                                        state.best_score = score;
604                                        state.best_hyperparameters = config.clone();
605                                    }
606                                }
607                            }
608                            Err(e) => {
609                                if !matches!(
610                                    self.config.error_handling,
611                                    ErrorHandlingStrategy::FailFast
612                                ) {
613                                    let evaluation_time = start_time.elapsed();
614                                    let result = EvaluationResult {
615                                        hyperparameters: config.clone(),
616                                        score: Float::NEG_INFINITY,
617                                        evaluation_time,
618                                        worker_id,
619                                        timestamp: start_time,
620                                        additional_metrics: HashMap::new(),
621                                        error: Some(e.to_string()),
622                                    };
623
624                                    if let Ok(mut state) = shared_state.write() {
625                                        state.evaluations.push(result);
626                                        state.completed_count += 1;
627                                    }
628                                }
629                            }
630                        }
631                    });
632            });
633
634            evaluations_completed += batch_size;
635
636            // Handle synchronization
637            match synchronization {
638                SynchronizationStrategy::Synchronous => {
639                    // Wait for all evaluations in batch to complete
640                    // Already handled by rayon's parallel execution
641                }
642                SynchronizationStrategy::Asynchronous => {
643                    // Continue immediately with next batch
644                    break;
645                }
646                SynchronizationStrategy::Hybrid { sync_interval } => {
647                    if evaluations_completed % sync_interval == 0 {
648                        // Synchronize periodically
649                        std::thread::sleep(Duration::from_millis(10));
650                    }
651                }
652            }
653        }
654
655        self.create_result()
656    }
657
658    /// Asynchronous optimization implementation
659    fn asynchronous_optimization<F>(
660        &mut self,
661        evaluation_fn: Arc<F>,
662        parameter_bounds: &[(Float, Float)],
663        max_evaluations: usize,
664    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
665    where
666        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
667            + Send
668            + Sync
669            + 'static,
670    {
671        // Simplified asynchronous optimization using rayon
672        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
673    }
674
675    /// Distributed optimization implementation
676    fn distributed_optimization<F>(
677        &mut self,
678        evaluation_fn: Arc<F>,
679        parameter_bounds: &[(Float, Float)],
680        max_evaluations: usize,
681    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
682    where
683        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
684            + Send
685            + Sync
686            + 'static,
687    {
688        // Simplified distributed optimization - fallback to parallel random search
689        // In a real implementation, this would coordinate across multiple machines
690        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
691    }
692
693    /// Multi-objective parallel optimization implementation
694    fn multi_objective_parallel_optimization<F>(
695        &mut self,
696        evaluation_fn: Arc<F>,
697        parameter_bounds: &[(Float, Float)],
698        max_evaluations: usize,
699    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
700    where
701        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
702            + Send
703            + Sync
704            + 'static,
705    {
706        // Simplified multi-objective optimization - use single objective for now
707        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
708    }
709
710    /// Generate grid configurations
711    fn generate_grid_configurations(
712        &self,
713        parameter_bounds: &[(Float, Float)],
714        max_evaluations: usize,
715    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
716        let n_params = parameter_bounds.len();
717        let n_values_per_param = (max_evaluations as Float)
718            .powf(1.0 / n_params as Float)
719            .ceil() as usize;
720
721        let mut configurations = Vec::new();
722        let mut indices = vec![0; n_params];
723
724        loop {
725            let mut config = HashMap::new();
726            for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
727                let value =
728                    low + (high - low) * (indices[i] as Float) / (n_values_per_param - 1) as Float;
729                config.insert(format!("param_{}", i), value);
730            }
731            configurations.push(config);
732
733            // Increment indices
734            let mut carry = 1;
735            for i in 0..n_params {
736                indices[i] += carry;
737                if indices[i] < n_values_per_param {
738                    carry = 0;
739                    break;
740                } else {
741                    indices[i] = 0;
742                }
743            }
744
745            if carry == 1 || configurations.len() >= max_evaluations {
746                break;
747            }
748        }
749
750        Ok(configurations)
751    }
752
753    /// Sample random configuration
754    fn sample_random_configuration(
755        &self,
756        parameter_bounds: &[(Float, Float)],
757        rng: &mut StdRng,
758    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
759        let mut config = HashMap::new();
760
761        for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
762            let value = rng.random_range(low..high + 1.0);
763            config.insert(format!("param_{}", i), value);
764        }
765
766        Ok(config)
767    }
768
769    /// Calculate dynamic batch size
770    fn calculate_dynamic_batch_size(
771        &self,
772        base_batch_size: usize,
773    ) -> Result<usize, Box<dyn std::error::Error>> {
774        // Simple heuristic: adjust based on recent evaluation times
775        if let Ok(state) = self.shared_state.read() {
776            if state.evaluations.len() >= 10 {
777                let recent_evaluations = &state.evaluations[state.evaluations.len() - 10..];
778                let avg_time = recent_evaluations
779                    .iter()
780                    .map(|e| e.evaluation_time.as_secs_f64())
781                    .sum::<f64>()
782                    / recent_evaluations.len() as f64;
783
784                // Adjust batch size based on evaluation time
785                if avg_time < 1.0 {
786                    Ok(base_batch_size * 2) // Fast evaluations, increase batch size
787                } else if avg_time > 10.0 {
788                    Ok(base_batch_size / 2) // Slow evaluations, decrease batch size
789                } else {
790                    Ok(base_batch_size)
791                }
792            } else {
793                Ok(base_batch_size)
794            }
795        } else {
796            Ok(base_batch_size)
797        }
798    }
799
800    /// Update Gaussian Process model
801    fn update_gaussian_process_model(&mut self) -> Result<(), Box<dyn std::error::Error>> {
802        if let Ok(mut state) = self.shared_state.write() {
803            let observations: Vec<(Vec<Float>, Float)> = state
804                .evaluations
805                .iter()
806                .filter(|e| e.error.is_none())
807                .map(|e| {
808                    let params: Vec<Float> = e.hyperparameters.values().cloned().collect();
809                    (params, e.score)
810                })
811                .collect();
812
813            if observations.len() >= 3 {
814                let gp = SimplifiedGP {
815                    observations,
816                    hyperparameters: GPHyperparams {
817                        length_scale: 1.0,
818                        signal_variance: 1.0,
819                        noise_variance: 0.1,
820                    },
821                    trained: true,
822                };
823                state.gaussian_process_model = Some(gp);
824            }
825        }
826        Ok(())
827    }
828
829    /// Generate acquisition batch
830    fn generate_acquisition_batch(
831        &self,
832        _acquisition_strategy: &BatchAcquisitionStrategy,
833        parameter_bounds: &[(Float, Float)],
834        batch_size: usize,
835    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
836        let mut rng = StdRng::seed_from_u64(42); // Fixed seed for reproducibility
837        let mut batch = Vec::new();
838
839        for _ in 0..batch_size {
840            // Simplified acquisition - just sample randomly for now
841            // In a real implementation, this would use the acquisition function
842            batch.push(self.sample_random_configuration(parameter_bounds, &mut rng)?);
843        }
844
845        Ok(batch)
846    }
847
848    /// Create optimization result
849    fn create_result(&self) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>> {
850        let state = self.shared_state.read().expect("operation should succeed");
851
852        let successful_evaluations = state
853            .evaluations
854            .iter()
855            .filter(|e| e.error.is_none())
856            .count();
857
858        let failed_evaluations = state.evaluations.len() - successful_evaluations;
859
860        let total_evaluation_time: Duration =
861            state.evaluations.iter().map(|e| e.evaluation_time).sum();
862
863        let average_evaluation_time = if state.evaluations.is_empty() {
864            Duration::from_secs(0)
865        } else {
866            total_evaluation_time / state.evaluations.len() as u32
867        };
868
869        // Calculate worker statistics
870        let mut worker_stats = HashMap::new();
871        for eval in &state.evaluations {
872            let stats = worker_stats
873                .entry(eval.worker_id)
874                .or_insert(WorkerStatistics {
875                    worker_id: eval.worker_id,
876                    evaluations_completed: 0,
877                    total_computation_time: Duration::from_secs(0),
878                    idle_time: Duration::from_secs(0),
879                    errors_encountered: 0,
880                    average_evaluation_time: Duration::from_secs(0),
881                });
882
883            stats.evaluations_completed += 1;
884            stats.total_computation_time += eval.evaluation_time;
885            if eval.error.is_some() {
886                stats.errors_encountered += 1;
887            }
888        }
889
890        for stats in worker_stats.values_mut() {
891            if stats.evaluations_completed > 0 {
892                stats.average_evaluation_time =
893                    stats.total_computation_time / stats.evaluations_completed as u32;
894            }
895        }
896
897        Ok(ParallelOptimizationResult {
898            best_hyperparameters: state.best_hyperparameters.clone(),
899            best_score: state.best_score,
900            all_evaluations: state.evaluations.clone(),
901            optimization_statistics: OptimizationStatistics {
902                total_evaluations: state.evaluations.len(),
903                successful_evaluations,
904                failed_evaluations,
905                average_evaluation_time,
906                convergence_rate: 0.1, // Placeholder
907                resource_utilization: ResourceUtilization {
908                    cpu_utilization: 0.8,
909                    memory_utilization: 0.6,
910                    network_utilization: 0.1,
911                    idle_time_percentage: 0.1,
912                },
913            },
914            worker_statistics: worker_stats.into_values().collect(),
915            parallelization_efficiency: successful_evaluations as Float
916                / self.config.max_workers as Float,
917            total_wall_time: total_evaluation_time,
918            total_cpu_time: total_evaluation_time * self.config.max_workers as u32,
919        })
920    }
921}
922
923/// Convenience function for parallel optimization
924pub fn parallel_optimize<F>(
925    evaluation_fn: F,
926    parameter_bounds: &[(Float, Float)],
927    max_evaluations: usize,
928    config: Option<ParallelOptimizationConfig>,
929) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
930where
931    F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
932        + Send
933        + Sync
934        + 'static,
935{
936    let config = config.unwrap_or_default();
937    let mut optimizer = ParallelOptimizer::new(config)?;
938    optimizer.optimize(evaluation_fn, parameter_bounds, max_evaluations)
939}
940
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Concave test objective: `-Σ (x - 0.5)²` over all hyperparameter values.
    ///
    /// Its maximum is `0.0`, so any observed best score must be `<= 0.0`.
    fn mock_evaluation_function(
        hyperparameters: &HashMap<String, Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let score = hyperparameters
            .values()
            .map(|&x| -(x - 0.5).powi(2))
            .sum::<Float>();
        Ok(score)
    }

    #[test]
    fn test_parallel_optimizer_creation() {
        let config = ParallelOptimizationConfig::default();
        let optimizer = ParallelOptimizer::new(config);
        assert!(optimizer.is_ok());
    }

    #[test]
    fn test_parallel_random_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelRandomSearch {
                batch_size: 4,
                dynamic_batching: false,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = [(0.0, 1.0), (0.0, 1.0)];

        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            10,
            Some(config),
        )
        .expect("operation should succeed");

        // The objective's maximum is 0, so no score can exceed it.
        assert!(result.best_score <= 0.0);
        // Batched parallel execution may overshoot the requested 10
        // evaluations, so only a loose upper bound is checked.
        assert!(result.optimization_statistics.total_evaluations >= 10);
        assert!(result.optimization_statistics.total_evaluations <= 16);
        assert!(!result.worker_statistics.is_empty());
    }

    #[test]
    fn test_parallel_grid_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelGridSearch {
                chunk_size: 2,
                load_balancing: LoadBalancingStrategy::Static,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = [(0.0, 1.0), (0.0, 1.0)];

        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            9, // 3x3 grid
            Some(config),
        )
        .expect("operation should succeed");

        // The objective's maximum is 0, so no score can exceed it.
        assert!(result.best_score <= 0.0);
        assert!(result.optimization_statistics.total_evaluations > 0);
    }

    #[test]
    fn test_error_handling() {
        let failing_function =
            |_: &HashMap<String, Float>| -> Result<Float, Box<dyn std::error::Error>> {
                Err("Test error".into())
            };

        let config = ParallelOptimizationConfig {
            error_handling: ErrorHandlingStrategy::SkipErrors,
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = [(0.0, 1.0)];

        let result = parallel_optimize(failing_function, &parameter_bounds, 5, Some(config))
            .expect("operation should succeed");

        // With SkipErrors every evaluation is recorded as a failure. Batching
        // may run more evaluations than the 5 requested.
        assert!(result.optimization_statistics.failed_evaluations >= 5);
        assert_eq!(result.optimization_statistics.successful_evaluations, 0);
        assert_eq!(
            result.optimization_statistics.total_evaluations,
            result.optimization_statistics.failed_evaluations
        );
    }
}