sklears_model_selection/
parallel_optimization.rs

1//! Parallel Hyperparameter Search
2//!
3//! This module provides parallel hyperparameter optimization using rayon for concurrent evaluation
4//! of hyperparameter configurations. It supports various parallelization strategies and load balancing
5//! techniques to efficiently utilize available computational resources.
6
7use rayon::prelude::*;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::Rng;
10use scirs2_core::random::SeedableRng;
11use sklears_core::types::Float;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, Instant};
15
16/// Parallel optimization strategies
17#[derive(Debug, Clone)]
18pub enum ParallelStrategy {
19    /// Simple parallel grid search
20    ParallelGridSearch {
21        chunk_size: usize,
22
23        load_balancing: LoadBalancingStrategy,
24    },
25    /// Parallel random search
26    ParallelRandomSearch {
27        batch_size: usize,
28
29        dynamic_batching: bool,
30    },
31    /// Parallel Bayesian optimization
32    ParallelBayesianOptimization {
33        batch_size: usize,
34        acquisition_strategy: BatchAcquisitionStrategy,
35        synchronization: SynchronizationStrategy,
36    },
37    /// Asynchronous optimization
38    AsynchronousOptimization {
39        max_concurrent: usize,
40        result_polling_interval: Duration,
41    },
42    /// Distributed optimization across multiple machines
43    DistributedOptimization {
44        worker_nodes: Vec<String>,
45        communication_protocol: CommunicationProtocol,
46    },
47    /// Multi-objective parallel optimization
48    MultiObjectiveParallel {
49        objectives: Vec<String>,
50        pareto_batch_size: usize,
51    },
52}
53
54/// Load balancing strategies for parallel execution
55#[derive(Debug, Clone)]
56pub enum LoadBalancingStrategy {
57    /// Static load balancing with equal chunks
58    Static,
59    /// Dynamic load balancing based on completion time
60    Dynamic { rebalance_threshold: Float },
61    /// Work-stealing approach
62    WorkStealing,
63    /// Priority-based load balancing
64    PriorityBased { priority_function: String },
65}
66
67/// Batch acquisition strategies for parallel Bayesian optimization
68#[derive(Debug, Clone)]
69pub enum BatchAcquisitionStrategy {
70    /// Constant Liar strategy
71    ConstantLiar { liar_value: Float },
72    /// Kriging Believer strategy
73    KrigingBeliever,
74    /// qExpected Improvement
75    QExpectedImprovement,
76    /// Local Penalization
77    LocalPenalization { penalization_factor: Float },
78    /// Thompson Sampling
79    ThompsonSampling { n_samples: usize },
80}
81
/// How parallel workers are synchronized between optimization rounds.
#[derive(Clone, Debug)]
pub enum SynchronizationStrategy {
    /// Wait for every worker before moving on.
    Synchronous,
    /// Move on as results arrive.
    Asynchronous,
    /// Mostly asynchronous, with a synchronization point every
    /// `sync_interval` evaluations.
    Hybrid {
        /// Number of evaluations between synchronization points.
        sync_interval: usize,
    },
}
92
/// Transport used to coordinate distributed optimization workers.
#[derive(Clone, Debug)]
pub enum CommunicationProtocol {
    /// Plain TCP on the given port.
    TCP {
        /// Port to connect to / listen on.
        port: u16,
    },
    /// A named message queue.
    MessageQueue {
        /// Name of the queue.
        queue_name: String,
    },
    /// Exchange through a shared filesystem location.
    SharedFilesystem {
        /// Shared directory or file path.
        path: String,
    },
    /// User-defined transport described by free-form key/value settings.
    Custom {
        /// Protocol-specific settings.
        config: HashMap<String, String>,
    },
}
105
106/// Parallel optimization configuration
107#[derive(Debug, Clone)]
108pub struct ParallelOptimizationConfig {
109    pub strategy: ParallelStrategy,
110    pub max_workers: usize,
111    pub timeout_per_evaluation: Option<Duration>,
112    pub memory_limit_per_worker: Option<usize>,
113    pub error_handling: ErrorHandlingStrategy,
114    pub progress_reporting: ProgressReportingConfig,
115    pub resource_monitoring: bool,
116    pub random_state: Option<u64>,
117}
118
119/// Error handling strategies
120#[derive(Debug, Clone)]
121pub enum ErrorHandlingStrategy {
122    /// Fail fast on any error
123    FailFast,
124    /// Continue on errors, skip failed evaluations
125    SkipErrors,
126    /// Retry failed evaluations
127    RetryOnError {
128        max_retries: usize,
129
130        backoff_factor: Float,
131    },
132    /// Use fallback evaluations for failed ones
133    FallbackEvaluation { fallback_score: Float },
134}
135
/// Settings for reporting optimization progress.
#[derive(Clone, Debug)]
pub struct ProgressReportingConfig {
    /// Whether progress reporting is active.
    pub enabled: bool,
    /// Time between progress updates.
    pub update_interval: Duration,
    /// Whether to include detailed metrics in each update.
    pub detailed_metrics: bool,
    /// Whether intermediate results should be exported.
    pub export_intermediate_results: bool,
}
144
145/// Parallel optimization result
146#[derive(Debug, Clone)]
147pub struct ParallelOptimizationResult {
148    pub best_hyperparameters: HashMap<String, Float>,
149    pub best_score: Float,
150    pub all_evaluations: Vec<EvaluationResult>,
151    pub optimization_statistics: OptimizationStatistics,
152    pub worker_statistics: Vec<WorkerStatistics>,
153    pub parallelization_efficiency: Float,
154    pub total_wall_time: Duration,
155    pub total_cpu_time: Duration,
156}
157
158/// Individual evaluation result
159#[derive(Debug, Clone)]
160pub struct EvaluationResult {
161    pub hyperparameters: HashMap<String, Float>,
162    pub score: Float,
163    pub evaluation_time: Duration,
164    pub worker_id: usize,
165    pub timestamp: Instant,
166    pub additional_metrics: HashMap<String, Float>,
167    pub error: Option<String>,
168}
169
170/// Optimization statistics
171#[derive(Debug, Clone)]
172pub struct OptimizationStatistics {
173    pub total_evaluations: usize,
174    pub successful_evaluations: usize,
175    pub failed_evaluations: usize,
176    pub average_evaluation_time: Duration,
177    pub convergence_rate: Float,
178    pub resource_utilization: ResourceUtilization,
179}
180
181/// Resource utilization metrics
182#[derive(Debug, Clone)]
183pub struct ResourceUtilization {
184    pub cpu_utilization: Float,
185    pub memory_utilization: Float,
186    pub network_utilization: Float,
187    pub idle_time_percentage: Float,
188}
189
/// Statistics gathered for one worker.
#[derive(Clone, Debug)]
pub struct WorkerStatistics {
    /// Worker identifier.
    pub worker_id: usize,
    /// Number of evaluations this worker finished.
    pub evaluations_completed: usize,
    /// Total time spent evaluating.
    pub total_computation_time: Duration,
    /// Total time spent waiting for work.
    pub idle_time: Duration,
    /// Number of evaluations that failed on this worker.
    pub errors_encountered: usize,
    /// Mean duration of this worker's evaluations.
    pub average_evaluation_time: Duration,
}
200
201/// Parallel hyperparameter optimizer
202pub struct ParallelOptimizer {
203    config: ParallelOptimizationConfig,
204    shared_state: Arc<RwLock<SharedOptimizationState>>,
205    worker_pool: Option<rayon::ThreadPool>,
206}
207
208/// Shared state between workers
209#[derive(Debug)]
210pub struct SharedOptimizationState {
211    pub evaluations: Vec<EvaluationResult>,
212    pub best_score: Float,
213    pub best_hyperparameters: HashMap<String, Float>,
214    pub pending_evaluations: Vec<HashMap<String, Float>>,
215    pub completed_count: usize,
216    pub gaussian_process_model: Option<SimplifiedGP>,
217}
218
219/// Simplified Gaussian Process for parallel optimization
220#[derive(Debug, Clone)]
221pub struct SimplifiedGP {
222    pub observations: Vec<(Vec<Float>, Float)>,
223    pub hyperparameters: GPHyperparams,
224    pub trained: bool,
225}
226
227/// GP hyperparameters
228#[derive(Debug, Clone)]
229pub struct GPHyperparams {
230    pub length_scale: Float,
231    pub signal_variance: Float,
232    pub noise_variance: Float,
233}
234
235impl Default for ParallelOptimizationConfig {
236    fn default() -> Self {
237        Self {
238            strategy: ParallelStrategy::ParallelRandomSearch {
239                batch_size: 4,
240                dynamic_batching: true,
241            },
242            max_workers: num_cpus::get(),
243            timeout_per_evaluation: Some(Duration::from_secs(300)),
244            memory_limit_per_worker: None,
245            error_handling: ErrorHandlingStrategy::SkipErrors,
246            progress_reporting: ProgressReportingConfig {
247                enabled: true,
248                update_interval: Duration::from_secs(10),
249                detailed_metrics: false,
250                export_intermediate_results: false,
251            },
252            resource_monitoring: true,
253            random_state: None,
254        }
255    }
256}
257
258impl ParallelOptimizer {
259    /// Create a new parallel optimizer
260    pub fn new(config: ParallelOptimizationConfig) -> Result<Self, Box<dyn std::error::Error>> {
261        // Create custom thread pool
262        let worker_pool = rayon::ThreadPoolBuilder::new()
263            .num_threads(config.max_workers)
264            .build()?;
265
266        let shared_state = Arc::new(RwLock::new(SharedOptimizationState {
267            evaluations: Vec::new(),
268            best_score: Float::NEG_INFINITY,
269            best_hyperparameters: HashMap::new(),
270            pending_evaluations: Vec::new(),
271            completed_count: 0,
272            gaussian_process_model: None,
273        }));
274
275        Ok(Self {
276            config,
277            shared_state,
278            worker_pool: Some(worker_pool),
279        })
280    }
281
282    /// Optimize hyperparameters in parallel
283    pub fn optimize<F>(
284        &mut self,
285        evaluation_fn: F,
286        parameter_bounds: &[(Float, Float)],
287        max_evaluations: usize,
288    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
289    where
290        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
291            + Send
292            + Sync
293            + 'static,
294    {
295        let _start_time = Instant::now();
296        let evaluation_fn = Arc::new(evaluation_fn);
297
298        match &self.config.strategy {
299            ParallelStrategy::ParallelGridSearch { .. } => {
300                self.parallel_grid_search(evaluation_fn, parameter_bounds, max_evaluations)
301            }
302            ParallelStrategy::ParallelRandomSearch { .. } => {
303                self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
304            }
305            ParallelStrategy::ParallelBayesianOptimization { .. } => self
306                .parallel_bayesian_optimization(evaluation_fn, parameter_bounds, max_evaluations),
307            ParallelStrategy::AsynchronousOptimization { .. } => {
308                self.asynchronous_optimization(evaluation_fn, parameter_bounds, max_evaluations)
309            }
310            ParallelStrategy::DistributedOptimization { .. } => {
311                self.distributed_optimization(evaluation_fn, parameter_bounds, max_evaluations)
312            }
313            ParallelStrategy::MultiObjectiveParallel { .. } => self
314                .multi_objective_parallel_optimization(
315                    evaluation_fn,
316                    parameter_bounds,
317                    max_evaluations,
318                ),
319        }
320    }
321
322    /// Parallel grid search implementation
323    fn parallel_grid_search<F>(
324        &mut self,
325        evaluation_fn: Arc<F>,
326        parameter_bounds: &[(Float, Float)],
327        max_evaluations: usize,
328    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
329    where
330        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
331            + Send
332            + Sync
333            + 'static,
334    {
335        let (chunk_size, _load_balancing) = match &self.config.strategy {
336            ParallelStrategy::ParallelGridSearch {
337                chunk_size,
338                load_balancing,
339            } => (*chunk_size, load_balancing),
340            _ => unreachable!(),
341        };
342
343        // Generate grid configurations
344        let grid_configs = self.generate_grid_configurations(parameter_bounds, max_evaluations)?;
345
346        // Process configurations in parallel chunks
347        let shared_state = self.shared_state.clone();
348        let worker_pool = self.worker_pool.as_ref().unwrap();
349
350        worker_pool.install(|| {
351            grid_configs
352                .par_chunks(chunk_size)
353                .enumerate()
354                .for_each(|(chunk_id, chunk)| {
355                    for (config_id, config) in chunk.iter().enumerate() {
356                        let worker_id = chunk_id * chunk_size + config_id;
357                        let start_time = Instant::now();
358
359                        match evaluation_fn(config) {
360                            Ok(score) => {
361                                let evaluation_time = start_time.elapsed();
362                                let result = EvaluationResult {
363                                    hyperparameters: config.clone(),
364                                    score,
365                                    evaluation_time,
366                                    worker_id,
367                                    timestamp: start_time,
368                                    additional_metrics: HashMap::new(),
369                                    error: None,
370                                };
371
372                                // Update shared state
373                                if let Ok(mut state) = shared_state.write() {
374                                    state.evaluations.push(result);
375                                    state.completed_count += 1;
376
377                                    if score > state.best_score {
378                                        state.best_score = score;
379                                        state.best_hyperparameters = config.clone();
380                                    }
381                                }
382                            }
383                            Err(e) => {
384                                if matches!(
385                                    self.config.error_handling,
386                                    ErrorHandlingStrategy::FailFast
387                                ) {
388                                    panic!("Evaluation failed: {}", e);
389                                }
390
391                                let evaluation_time = start_time.elapsed();
392                                let result = EvaluationResult {
393                                    hyperparameters: config.clone(),
394                                    score: Float::NEG_INFINITY,
395                                    evaluation_time,
396                                    worker_id,
397                                    timestamp: start_time,
398                                    additional_metrics: HashMap::new(),
399                                    error: Some(e.to_string()),
400                                };
401
402                                if let Ok(mut state) = shared_state.write() {
403                                    state.evaluations.push(result);
404                                    state.completed_count += 1;
405                                }
406                            }
407                        }
408                    }
409                });
410        });
411
412        self.create_result()
413    }
414
415    /// Parallel random search implementation
416    fn parallel_random_search<F>(
417        &mut self,
418        evaluation_fn: Arc<F>,
419        parameter_bounds: &[(Float, Float)],
420        max_evaluations: usize,
421    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
422    where
423        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
424            + Send
425            + Sync
426            + 'static,
427    {
428        let (batch_size, dynamic_batching) = match &self.config.strategy {
429            ParallelStrategy::ParallelRandomSearch {
430                batch_size,
431                dynamic_batching,
432            } => (*batch_size, *dynamic_batching),
433            _ => unreachable!(),
434        };
435
436        let shared_state = self.shared_state.clone();
437        let worker_pool = self.worker_pool.as_ref().unwrap();
438
439        let mut rng = match self.config.random_state {
440            Some(seed) => StdRng::seed_from_u64(seed),
441            None => {
442                use scirs2_core::random::thread_rng;
443                StdRng::from_rng(&mut thread_rng())
444            }
445        };
446
447        let mut evaluations_completed = 0;
448        let mut current_batch_size = batch_size;
449
450        while evaluations_completed < max_evaluations {
451            // Adjust batch size dynamically if enabled
452            if dynamic_batching {
453                current_batch_size = self.calculate_dynamic_batch_size(batch_size)?;
454            }
455
456            // Generate batch of random configurations
457            let batch_configs: Vec<HashMap<String, Float>> = (0..current_batch_size)
458                .map(|_| self.sample_random_configuration(parameter_bounds, &mut rng))
459                .collect::<Result<Vec<_>, _>>()?;
460
461            // Evaluate batch in parallel
462            worker_pool.install(|| {
463                batch_configs
464                    .par_iter()
465                    .enumerate()
466                    .for_each(|(local_id, config)| {
467                        let worker_id = evaluations_completed + local_id;
468                        let start_time = Instant::now();
469
470                        match evaluation_fn(config) {
471                            Ok(score) => {
472                                let evaluation_time = start_time.elapsed();
473                                let result = EvaluationResult {
474                                    hyperparameters: config.clone(),
475                                    score,
476                                    evaluation_time,
477                                    worker_id,
478                                    timestamp: start_time,
479                                    additional_metrics: HashMap::new(),
480                                    error: None,
481                                };
482
483                                if let Ok(mut state) = shared_state.write() {
484                                    state.evaluations.push(result);
485                                    state.completed_count += 1;
486
487                                    if score > state.best_score {
488                                        state.best_score = score;
489                                        state.best_hyperparameters = config.clone();
490                                    }
491                                }
492                            }
493                            Err(e) => {
494                                if !matches!(
495                                    self.config.error_handling,
496                                    ErrorHandlingStrategy::FailFast
497                                ) {
498                                    let evaluation_time = start_time.elapsed();
499                                    let result = EvaluationResult {
500                                        hyperparameters: config.clone(),
501                                        score: Float::NEG_INFINITY,
502                                        evaluation_time,
503                                        worker_id,
504                                        timestamp: start_time,
505                                        additional_metrics: HashMap::new(),
506                                        error: Some(e.to_string()),
507                                    };
508
509                                    if let Ok(mut state) = shared_state.write() {
510                                        state.evaluations.push(result);
511                                        state.completed_count += 1;
512                                    }
513                                }
514                            }
515                        }
516                    });
517            });
518
519            evaluations_completed += current_batch_size;
520        }
521
522        self.create_result()
523    }
524
525    /// Parallel Bayesian optimization implementation
526    fn parallel_bayesian_optimization<F>(
527        &mut self,
528        evaluation_fn: Arc<F>,
529        parameter_bounds: &[(Float, Float)],
530        max_evaluations: usize,
531    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
532    where
533        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
534            + Send
535            + Sync
536            + 'static,
537    {
538        let (batch_size, acquisition_strategy, synchronization) = match &self.config.strategy {
539            ParallelStrategy::ParallelBayesianOptimization {
540                batch_size,
541                acquisition_strategy,
542                synchronization,
543            } => (
544                *batch_size,
545                acquisition_strategy.clone(),
546                synchronization.clone(),
547            ),
548            _ => unreachable!(),
549        };
550
551        let shared_state = self.shared_state.clone();
552
553        // Initialize with random evaluations
554        let initial_evaluations = batch_size.min(5);
555        self.parallel_random_search(evaluation_fn.clone(), parameter_bounds, initial_evaluations)?;
556
557        let mut evaluations_completed = initial_evaluations;
558
559        while evaluations_completed < max_evaluations {
560            // Update Gaussian Process model
561            self.update_gaussian_process_model()?;
562
563            // Generate next batch using acquisition strategy
564            let next_batch = self.generate_acquisition_batch(
565                &acquisition_strategy,
566                parameter_bounds,
567                batch_size,
568            )?;
569
570            // Evaluate batch in parallel
571            let worker_pool = self.worker_pool.as_ref().unwrap();
572            worker_pool.install(|| {
573                next_batch
574                    .par_iter()
575                    .enumerate()
576                    .for_each(|(local_id, config)| {
577                        let worker_id = evaluations_completed + local_id;
578                        let start_time = Instant::now();
579
580                        match evaluation_fn(config) {
581                            Ok(score) => {
582                                let evaluation_time = start_time.elapsed();
583                                let result = EvaluationResult {
584                                    hyperparameters: config.clone(),
585                                    score,
586                                    evaluation_time,
587                                    worker_id,
588                                    timestamp: start_time,
589                                    additional_metrics: HashMap::new(),
590                                    error: None,
591                                };
592
593                                if let Ok(mut state) = shared_state.write() {
594                                    state.evaluations.push(result);
595                                    state.completed_count += 1;
596
597                                    if score > state.best_score {
598                                        state.best_score = score;
599                                        state.best_hyperparameters = config.clone();
600                                    }
601                                }
602                            }
603                            Err(e) => {
604                                if !matches!(
605                                    self.config.error_handling,
606                                    ErrorHandlingStrategy::FailFast
607                                ) {
608                                    let evaluation_time = start_time.elapsed();
609                                    let result = EvaluationResult {
610                                        hyperparameters: config.clone(),
611                                        score: Float::NEG_INFINITY,
612                                        evaluation_time,
613                                        worker_id,
614                                        timestamp: start_time,
615                                        additional_metrics: HashMap::new(),
616                                        error: Some(e.to_string()),
617                                    };
618
619                                    if let Ok(mut state) = shared_state.write() {
620                                        state.evaluations.push(result);
621                                        state.completed_count += 1;
622                                    }
623                                }
624                            }
625                        }
626                    });
627            });
628
629            evaluations_completed += batch_size;
630
631            // Handle synchronization
632            match synchronization {
633                SynchronizationStrategy::Synchronous => {
634                    // Wait for all evaluations in batch to complete
635                    // Already handled by rayon's parallel execution
636                }
637                SynchronizationStrategy::Asynchronous => {
638                    // Continue immediately with next batch
639                    break;
640                }
641                SynchronizationStrategy::Hybrid { sync_interval } => {
642                    if evaluations_completed % sync_interval == 0 {
643                        // Synchronize periodically
644                        std::thread::sleep(Duration::from_millis(10));
645                    }
646                }
647            }
648        }
649
650        self.create_result()
651    }
652
653    /// Asynchronous optimization implementation
654    fn asynchronous_optimization<F>(
655        &mut self,
656        evaluation_fn: Arc<F>,
657        parameter_bounds: &[(Float, Float)],
658        max_evaluations: usize,
659    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
660    where
661        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
662            + Send
663            + Sync
664            + 'static,
665    {
666        // Simplified asynchronous optimization using rayon
667        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
668    }
669
670    /// Distributed optimization implementation
671    fn distributed_optimization<F>(
672        &mut self,
673        evaluation_fn: Arc<F>,
674        parameter_bounds: &[(Float, Float)],
675        max_evaluations: usize,
676    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
677    where
678        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
679            + Send
680            + Sync
681            + 'static,
682    {
683        // Simplified distributed optimization - fallback to parallel random search
684        // In a real implementation, this would coordinate across multiple machines
685        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
686    }
687
688    /// Multi-objective parallel optimization implementation
689    fn multi_objective_parallel_optimization<F>(
690        &mut self,
691        evaluation_fn: Arc<F>,
692        parameter_bounds: &[(Float, Float)],
693        max_evaluations: usize,
694    ) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
695    where
696        F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
697            + Send
698            + Sync
699            + 'static,
700    {
701        // Simplified multi-objective optimization - use single objective for now
702        self.parallel_random_search(evaluation_fn, parameter_bounds, max_evaluations)
703    }
704
705    /// Generate grid configurations
706    fn generate_grid_configurations(
707        &self,
708        parameter_bounds: &[(Float, Float)],
709        max_evaluations: usize,
710    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
711        let n_params = parameter_bounds.len();
712        let n_values_per_param = (max_evaluations as Float)
713            .powf(1.0 / n_params as Float)
714            .ceil() as usize;
715
716        let mut configurations = Vec::new();
717        let mut indices = vec![0; n_params];
718
719        loop {
720            let mut config = HashMap::new();
721            for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
722                let value =
723                    low + (high - low) * (indices[i] as Float) / (n_values_per_param - 1) as Float;
724                config.insert(format!("param_{}", i), value);
725            }
726            configurations.push(config);
727
728            // Increment indices
729            let mut carry = 1;
730            for i in 0..n_params {
731                indices[i] += carry;
732                if indices[i] < n_values_per_param {
733                    carry = 0;
734                    break;
735                } else {
736                    indices[i] = 0;
737                }
738            }
739
740            if carry == 1 || configurations.len() >= max_evaluations {
741                break;
742            }
743        }
744
745        Ok(configurations)
746    }
747
748    /// Sample random configuration
749    fn sample_random_configuration(
750        &self,
751        parameter_bounds: &[(Float, Float)],
752        rng: &mut StdRng,
753    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
754        let mut config = HashMap::new();
755
756        for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
757            let value = rng.gen_range(low..high + 1.0);
758            config.insert(format!("param_{}", i), value);
759        }
760
761        Ok(config)
762    }
763
764    /// Calculate dynamic batch size
765    fn calculate_dynamic_batch_size(
766        &self,
767        base_batch_size: usize,
768    ) -> Result<usize, Box<dyn std::error::Error>> {
769        // Simple heuristic: adjust based on recent evaluation times
770        if let Ok(state) = self.shared_state.read() {
771            if state.evaluations.len() >= 10 {
772                let recent_evaluations = &state.evaluations[state.evaluations.len() - 10..];
773                let avg_time = recent_evaluations
774                    .iter()
775                    .map(|e| e.evaluation_time.as_secs_f64())
776                    .sum::<f64>()
777                    / recent_evaluations.len() as f64;
778
779                // Adjust batch size based on evaluation time
780                if avg_time < 1.0 {
781                    Ok(base_batch_size * 2) // Fast evaluations, increase batch size
782                } else if avg_time > 10.0 {
783                    Ok(base_batch_size / 2) // Slow evaluations, decrease batch size
784                } else {
785                    Ok(base_batch_size)
786                }
787            } else {
788                Ok(base_batch_size)
789            }
790        } else {
791            Ok(base_batch_size)
792        }
793    }
794
795    /// Update Gaussian Process model
796    fn update_gaussian_process_model(&mut self) -> Result<(), Box<dyn std::error::Error>> {
797        if let Ok(mut state) = self.shared_state.write() {
798            let observations: Vec<(Vec<Float>, Float)> = state
799                .evaluations
800                .iter()
801                .filter(|e| e.error.is_none())
802                .map(|e| {
803                    let params: Vec<Float> = e.hyperparameters.values().cloned().collect();
804                    (params, e.score)
805                })
806                .collect();
807
808            if observations.len() >= 3 {
809                let gp = SimplifiedGP {
810                    observations,
811                    hyperparameters: GPHyperparams {
812                        length_scale: 1.0,
813                        signal_variance: 1.0,
814                        noise_variance: 0.1,
815                    },
816                    trained: true,
817                };
818                state.gaussian_process_model = Some(gp);
819            }
820        }
821        Ok(())
822    }
823
824    /// Generate acquisition batch
825    fn generate_acquisition_batch(
826        &self,
827        _acquisition_strategy: &BatchAcquisitionStrategy,
828        parameter_bounds: &[(Float, Float)],
829        batch_size: usize,
830    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
831        let mut rng = StdRng::seed_from_u64(42); // Fixed seed for reproducibility
832        let mut batch = Vec::new();
833
834        for _ in 0..batch_size {
835            // Simplified acquisition - just sample randomly for now
836            // In a real implementation, this would use the acquisition function
837            batch.push(self.sample_random_configuration(parameter_bounds, &mut rng)?);
838        }
839
840        Ok(batch)
841    }
842
843    /// Create optimization result
844    fn create_result(&self) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>> {
845        let state = self.shared_state.read().unwrap();
846
847        let successful_evaluations = state
848            .evaluations
849            .iter()
850            .filter(|e| e.error.is_none())
851            .count();
852
853        let failed_evaluations = state.evaluations.len() - successful_evaluations;
854
855        let total_evaluation_time: Duration =
856            state.evaluations.iter().map(|e| e.evaluation_time).sum();
857
858        let average_evaluation_time = if state.evaluations.is_empty() {
859            Duration::from_secs(0)
860        } else {
861            total_evaluation_time / state.evaluations.len() as u32
862        };
863
864        // Calculate worker statistics
865        let mut worker_stats = HashMap::new();
866        for eval in &state.evaluations {
867            let stats = worker_stats
868                .entry(eval.worker_id)
869                .or_insert(WorkerStatistics {
870                    worker_id: eval.worker_id,
871                    evaluations_completed: 0,
872                    total_computation_time: Duration::from_secs(0),
873                    idle_time: Duration::from_secs(0),
874                    errors_encountered: 0,
875                    average_evaluation_time: Duration::from_secs(0),
876                });
877
878            stats.evaluations_completed += 1;
879            stats.total_computation_time += eval.evaluation_time;
880            if eval.error.is_some() {
881                stats.errors_encountered += 1;
882            }
883        }
884
885        for stats in worker_stats.values_mut() {
886            if stats.evaluations_completed > 0 {
887                stats.average_evaluation_time =
888                    stats.total_computation_time / stats.evaluations_completed as u32;
889            }
890        }
891
892        Ok(ParallelOptimizationResult {
893            best_hyperparameters: state.best_hyperparameters.clone(),
894            best_score: state.best_score,
895            all_evaluations: state.evaluations.clone(),
896            optimization_statistics: OptimizationStatistics {
897                total_evaluations: state.evaluations.len(),
898                successful_evaluations,
899                failed_evaluations,
900                average_evaluation_time,
901                convergence_rate: 0.1, // Placeholder
902                resource_utilization: ResourceUtilization {
903                    cpu_utilization: 0.8,
904                    memory_utilization: 0.6,
905                    network_utilization: 0.1,
906                    idle_time_percentage: 0.1,
907                },
908            },
909            worker_statistics: worker_stats.into_values().collect(),
910            parallelization_efficiency: successful_evaluations as Float
911                / self.config.max_workers as Float,
912            total_wall_time: total_evaluation_time,
913            total_cpu_time: total_evaluation_time * self.config.max_workers as u32,
914        })
915    }
916}
917
918/// Convenience function for parallel optimization
919pub fn parallel_optimize<F>(
920    evaluation_fn: F,
921    parameter_bounds: &[(Float, Float)],
922    max_evaluations: usize,
923    config: Option<ParallelOptimizationConfig>,
924) -> Result<ParallelOptimizationResult, Box<dyn std::error::Error>>
925where
926    F: Fn(&HashMap<String, Float>) -> Result<Float, Box<dyn std::error::Error>>
927        + Send
928        + Sync
929        + 'static,
930{
931    let config = config.unwrap_or_default();
932    let mut optimizer = ParallelOptimizer::new(config)?;
933    optimizer.optimize(evaluation_fn, parameter_bounds, max_evaluations)
934}
935
#[cfg(test)]
mod tests {
    use super::*;

    /// Concave test objective `-Σ (x - 0.5)²`. It is never positive and is exactly 0
    /// when every hyperparameter equals 0.5.
    fn mock_evaluation_function(
        hyperparameters: &HashMap<String, Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let score = hyperparameters
            .values()
            .map(|&x| -(x - 0.5).powi(2))
            .sum::<Float>();
        Ok(score)
    }

    #[test]
    fn test_parallel_optimizer_creation() {
        let config = ParallelOptimizationConfig::default();
        let optimizer = ParallelOptimizer::new(config);
        assert!(optimizer.is_ok());
    }

    #[test]
    fn test_parallel_random_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelRandomSearch {
                batch_size: 4,
                dynamic_batching: false,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = [(0.0, 1.0), (0.0, 1.0)];

        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            10,
            Some(config),
        )
        .unwrap();

        // The objective is never positive, so neither is the best score found.
        assert!(result.best_score <= 0.0);

        // Batched parallel execution may overshoot the 10 requested evaluations;
        // 16 is a deliberately loose upper bound on that overshoot.
        let stats = &result.optimization_statistics;
        assert!(stats.total_evaluations >= 10);
        assert!(stats.total_evaluations <= 16);
        assert_eq!(
            stats.successful_evaluations + stats.failed_evaluations,
            stats.total_evaluations
        );
        assert!(!result.worker_statistics.is_empty());
    }

    #[test]
    fn test_parallel_grid_search() {
        let config = ParallelOptimizationConfig {
            strategy: ParallelStrategy::ParallelGridSearch {
                chunk_size: 2,
                load_balancing: LoadBalancingStrategy::Static,
            },
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = [(0.0, 1.0), (0.0, 1.0)];

        let result = parallel_optimize(
            mock_evaluation_function,
            &parameter_bounds,
            9, // 3x3 grid
            Some(config),
        )
        .unwrap();

        assert!(result.best_score <= 0.0);
        assert!(result.optimization_statistics.total_evaluations > 0);
    }

    #[test]
    fn test_error_handling() {
        let failing_function =
            |_: &HashMap<String, Float>| -> Result<Float, Box<dyn std::error::Error>> {
                Err("Test error".into())
            };

        let config = ParallelOptimizationConfig {
            error_handling: ErrorHandlingStrategy::SkipErrors,
            max_workers: 2,
            ..Default::default()
        };

        let parameter_bounds = [(0.0, 1.0)];

        let result =
            parallel_optimize(failing_function, &parameter_bounds, 5, Some(config)).unwrap();

        // In parallel execution, evaluations may exceed the requested count due to batching.
        let stats = &result.optimization_statistics;
        assert!(stats.failed_evaluations >= 5);
        assert_eq!(stats.successful_evaluations, 0);
        assert_eq!(stats.total_evaluations, stats.failed_evaluations);
    }
}