Skip to main content

scirs2_optimize/bayesian/
optimizer.rs

1//! Bayesian Optimizer -- the main driver for Bayesian optimization.
2//!
3//! Orchestrates the GP surrogate, acquisition function, and sampling strategy
4//! into a full sequential/batch optimization loop.
5//!
6//! # Features
7//!
8//! - Configurable surrogate (GP with any kernel)
9//! - Pluggable acquisition functions (EI, PI, UCB, KG, Thompson, batch variants)
10//! - Initial design via Latin Hypercube, Sobol, Halton, or random sampling
11//! - Sequential and batch optimization loops
12//! - Multi-objective Bayesian optimization via ParEGO scalarization
13//! - Constraint handling via augmented acquisition
14//! - Warm-starting from previous evaluations
15
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
17use scirs2_core::random::rngs::StdRng;
18use scirs2_core::random::{Rng, RngExt, SeedableRng};
19
20use crate::error::{OptimizeError, OptimizeResult};
21
22use super::acquisition::{AcquisitionFn, AcquisitionType, ExpectedImprovement};
23use super::gp::{GpSurrogate, GpSurrogateConfig, RbfKernel, SurrogateKernel};
24use super::sampling::{generate_samples, SamplingConfig, SamplingStrategy};
25
26// ---------------------------------------------------------------------------
27// Configuration
28// ---------------------------------------------------------------------------
29
30/// Configuration for the Bayesian optimizer.
31#[derive(Clone)]
32pub struct BayesianOptimizerConfig {
33    /// Acquisition function type.
34    pub acquisition: AcquisitionType,
35    /// Sampling strategy for initial design.
36    pub initial_design: SamplingStrategy,
37    /// Number of initial random/quasi-random points.
38    pub n_initial: usize,
39    /// Number of restarts when optimising the acquisition function.
40    pub acq_n_restarts: usize,
41    /// Number of random candidates evaluated per restart when optimising acquisition.
42    pub acq_n_candidates: usize,
43    /// GP surrogate configuration.
44    pub gp_config: GpSurrogateConfig,
45    /// Random seed for reproducibility.
46    pub seed: Option<u64>,
47    /// Verbosity level (0 = silent, 1 = summary, 2 = per-iteration).
48    pub verbose: usize,
49}
50
51impl Default for BayesianOptimizerConfig {
52    fn default() -> Self {
53        Self {
54            acquisition: AcquisitionType::EI { xi: 0.01 },
55            initial_design: SamplingStrategy::LatinHypercube,
56            n_initial: 10,
57            acq_n_restarts: 5,
58            acq_n_candidates: 200,
59            gp_config: GpSurrogateConfig::default(),
60            seed: None,
61            verbose: 0,
62        }
63    }
64}
65
66// ---------------------------------------------------------------------------
67// Observation record
68// ---------------------------------------------------------------------------
69
70/// A single evaluated observation.
71#[derive(Debug, Clone)]
72pub struct Observation {
73    /// Input point.
74    pub x: Array1<f64>,
75    /// Objective function value.
76    pub y: f64,
77    /// Constraint violation values (empty if no constraints).
78    pub constraints: Vec<f64>,
79    /// Whether this point is feasible (all constraints satisfied).
80    pub feasible: bool,
81}
82
83// ---------------------------------------------------------------------------
84// Optimization result
85// ---------------------------------------------------------------------------
86
87/// Result of Bayesian optimization.
88#[derive(Debug, Clone)]
89pub struct BayesianOptResult {
90    /// Best input point found.
91    pub x_best: Array1<f64>,
92    /// Best objective function value found.
93    pub f_best: f64,
94    /// All observations in order.
95    pub observations: Vec<Observation>,
96    /// Number of function evaluations.
97    pub n_evals: usize,
98    /// History of best values found at each iteration.
99    pub best_history: Vec<f64>,
100    /// Whether the optimisation was successful.
101    pub success: bool,
102    /// Message about the optimization.
103    pub message: String,
104}
105
106// ---------------------------------------------------------------------------
107// Constraint specification
108// ---------------------------------------------------------------------------
109
110/// A constraint for constrained Bayesian optimization.
111///
112/// The constraint is satisfied when `g(x) <= 0`.
113pub struct Constraint {
114    /// Constraint function: returns a scalar value; satisfied when <= 0.
115    pub func: Box<dyn Fn(&ArrayView1<f64>) -> f64 + Send + Sync>,
116    /// Name for diagnostic purposes.
117    pub name: String,
118}
119
120// ---------------------------------------------------------------------------
121// Dimension type
122// ---------------------------------------------------------------------------
123
/// Kind of values a search dimension may take.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DimensionType {
    /// Real-valued dimension; candidates are not rounded.
    Continuous,
    /// Integer-valued dimension — candidates are rounded to the nearest integer
    /// and clamped to the declared bounds.
    Integer,
}
133
134// ---------------------------------------------------------------------------
135// BayesianOptimizer
136// ---------------------------------------------------------------------------
137
138/// The Bayesian optimizer.
139///
140/// Supports sequential single-objective, batch, multi-objective (ParEGO),
141/// and constrained optimization.
142pub struct BayesianOptimizer {
143    /// Search bounds: [(lower, upper), ...] for each dimension.
144    bounds: Vec<(f64, f64)>,
145    /// Configuration.
146    config: BayesianOptimizerConfig,
147    /// GP surrogate model.
148    surrogate: GpSurrogate,
149    /// Observations collected so far.
150    observations: Vec<Observation>,
151    /// Current best observation index.
152    best_idx: Option<usize>,
153    /// Constraints (empty for unconstrained).
154    constraints: Vec<Constraint>,
155    /// Random number generator.
156    rng: StdRng,
157    /// Optional per-dimension type info (Continuous or Integer).
158    /// When set, integer dimensions are rounded & clamped after every candidate generation.
159    dim_types: Option<Vec<DimensionType>>,
160}
161
162impl BayesianOptimizer {
163    /// Create a new Bayesian optimizer.
164    ///
165    /// # Arguments
166    /// * `bounds` - Search bounds for each dimension: `[(lo, hi), ...]`
167    /// * `config` - Optimizer configuration
168    pub fn new(bounds: Vec<(f64, f64)>, config: BayesianOptimizerConfig) -> OptimizeResult<Self> {
169        if bounds.is_empty() {
170            return Err(OptimizeError::InvalidInput(
171                "Bounds must have at least one dimension".to_string(),
172            ));
173        }
174        for (i, &(lo, hi)) in bounds.iter().enumerate() {
175            if lo >= hi {
176                return Err(OptimizeError::InvalidInput(format!(
177                    "Invalid bounds for dimension {}: [{}, {}]",
178                    i, lo, hi
179                )));
180            }
181        }
182
183        let seed = config.seed.unwrap_or_else(|| {
184            let s: u64 = scirs2_core::random::rng().random();
185            s
186        });
187        let rng = StdRng::seed_from_u64(seed);
188
189        let kernel: Box<dyn SurrogateKernel> = Box::new(RbfKernel::default());
190        let surrogate = GpSurrogate::new(kernel, config.gp_config.clone());
191
192        Ok(Self {
193            bounds,
194            config,
195            surrogate,
196            observations: Vec::new(),
197            best_idx: None,
198            constraints: Vec::new(),
199            rng,
200            dim_types: None,
201        })
202    }
203
204    /// Create a new optimizer with a custom kernel.
205    pub fn with_kernel(
206        bounds: Vec<(f64, f64)>,
207        kernel: Box<dyn SurrogateKernel>,
208        config: BayesianOptimizerConfig,
209    ) -> OptimizeResult<Self> {
210        let mut opt = Self::new(bounds, config)?;
211        opt.surrogate = GpSurrogate::new(kernel, opt.config.gp_config.clone());
212        Ok(opt)
213    }
214
215    /// Add a constraint: satisfied when `g(x) <= 0`.
216    pub fn add_constraint<F>(&mut self, name: &str, func: F)
217    where
218        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
219    {
220        self.constraints.push(Constraint {
221            func: Box::new(func),
222            name: name.to_string(),
223        });
224    }
225
226    /// Declare dimension types so integer dimensions are automatically rounded.
227    ///
228    /// The length of `types` must equal the number of dimensions (bounds).
229    pub fn set_dimension_types(&mut self, types: Vec<DimensionType>) -> OptimizeResult<()> {
230        if types.len() != self.bounds.len() {
231            return Err(OptimizeError::InvalidInput(format!(
232                "dimension_types length ({}) must match bounds length ({})",
233                types.len(),
234                self.bounds.len()
235            )));
236        }
237        self.dim_types = Some(types);
238        Ok(())
239    }
240
241    /// Round integer dimensions to the nearest integer and clamp to bounds.
242    fn enforce_dim_types(&self, x: &mut Array1<f64>) {
243        if let Some(ref types) = self.dim_types {
244            for (d, dt) in types.iter().enumerate() {
245                if *dt == DimensionType::Integer {
246                    let (lo, hi) = self.bounds[d];
247                    x[d] = x[d].round().clamp(lo, hi);
248                }
249            }
250        }
251    }
252
253    /// Warm-start from previous evaluations.
254    pub fn warm_start(&mut self, x_data: &Array2<f64>, y_data: &Array1<f64>) -> OptimizeResult<()> {
255        if x_data.nrows() != y_data.len() {
256            return Err(OptimizeError::InvalidInput(
257                "x_data and y_data row counts must match".to_string(),
258            ));
259        }
260
261        for i in 0..x_data.nrows() {
262            let obs = Observation {
263                x: x_data.row(i).to_owned(),
264                y: y_data[i],
265                constraints: Vec::new(),
266                feasible: true,
267            };
268
269            // Track best
270            match self.best_idx {
271                Some(best) if obs.y < self.observations[best].y => {
272                    self.best_idx = Some(self.observations.len());
273                }
274                None => {
275                    self.best_idx = Some(self.observations.len());
276                }
277                _ => {}
278            }
279            self.observations.push(obs);
280        }
281
282        // Fit the surrogate
283        if !self.observations.is_empty() {
284            self.fit_surrogate()?;
285        }
286
287        Ok(())
288    }
289
290    /// Run the sequential optimization loop.
291    ///
292    /// # Arguments
293    /// * `objective` - Function to minimize.
294    /// * `n_iter` - Number of iterations (function evaluations after initial design).
295    pub fn optimize<F>(&mut self, objective: F, n_iter: usize) -> OptimizeResult<BayesianOptResult>
296    where
297        F: Fn(&ArrayView1<f64>) -> f64,
298    {
299        // Phase 1: Initial design
300        let n_initial = if self.observations.is_empty() {
301            self.config.n_initial
302        } else {
303            // If warm-started, may need fewer initial points
304            self.config
305                .n_initial
306                .saturating_sub(self.observations.len())
307        };
308
309        if n_initial > 0 {
310            let sampling_config = SamplingConfig {
311                seed: Some(self.rng.random()),
312                ..Default::default()
313            };
314            let initial_points = generate_samples(
315                n_initial,
316                &self.bounds,
317                self.config.initial_design,
318                Some(sampling_config),
319            )?;
320
321            for i in 0..initial_points.nrows() {
322                let mut x = initial_points.row(i).to_owned();
323                self.enforce_dim_types(&mut x);
324                let y = objective(&x.view());
325                self.record_observation(x, y);
326            }
327
328            self.fit_surrogate()?;
329        }
330
331        let mut best_history = Vec::with_capacity(n_iter);
332        if let Some(best_idx) = self.best_idx {
333            best_history.push(self.observations[best_idx].y);
334        }
335
336        // Phase 2: Sequential optimization
337        for _iter in 0..n_iter {
338            let next_x = self.suggest_next()?;
339            let y = objective(&next_x.view());
340            self.record_observation(next_x, y);
341            self.fit_surrogate()?;
342
343            if let Some(best_idx) = self.best_idx {
344                best_history.push(self.observations[best_idx].y);
345            }
346        }
347
348        // Build result
349        let best_idx = self.best_idx.ok_or_else(|| {
350            OptimizeError::ComputationError("No observations collected".to_string())
351        })?;
352        let best_obs = &self.observations[best_idx];
353
354        Ok(BayesianOptResult {
355            x_best: best_obs.x.clone(),
356            f_best: best_obs.y,
357            observations: self.observations.clone(),
358            n_evals: self.observations.len(),
359            best_history,
360            success: true,
361            message: format!(
362                "Optimization completed: {} evaluations, best f = {:.6e}",
363                self.observations.len(),
364                best_obs.y
365            ),
366        })
367    }
368
369    /// Run batch optimization, evaluating `batch_size` points in parallel per round.
370    ///
371    /// Uses the Kriging Believer strategy: after selecting a candidate,
372    /// the GP is updated with a fantasised observation at the predicted mean.
373    pub fn optimize_batch<F>(
374        &mut self,
375        objective: F,
376        n_rounds: usize,
377        batch_size: usize,
378    ) -> OptimizeResult<BayesianOptResult>
379    where
380        F: Fn(&ArrayView1<f64>) -> f64,
381    {
382        let batch_size = batch_size.max(1);
383
384        // Phase 1: Initial design (same as sequential)
385        let n_initial = if self.observations.is_empty() {
386            self.config.n_initial
387        } else {
388            self.config
389                .n_initial
390                .saturating_sub(self.observations.len())
391        };
392
393        if n_initial > 0 {
394            let sampling_config = SamplingConfig {
395                seed: Some(self.rng.random()),
396                ..Default::default()
397            };
398            let initial_points = generate_samples(
399                n_initial,
400                &self.bounds,
401                self.config.initial_design,
402                Some(sampling_config),
403            )?;
404
405            for i in 0..initial_points.nrows() {
406                let mut x = initial_points.row(i).to_owned();
407                self.enforce_dim_types(&mut x);
408                let y = objective(&x.view());
409                self.record_observation(x, y);
410            }
411            self.fit_surrogate()?;
412        }
413
414        let mut best_history = Vec::with_capacity(n_rounds);
415        if let Some(best_idx) = self.best_idx {
416            best_history.push(self.observations[best_idx].y);
417        }
418
419        // Phase 2: Batch optimization rounds
420        for _round in 0..n_rounds {
421            let batch = self.suggest_batch(batch_size)?;
422
423            // Evaluate all batch points
424            for x in &batch {
425                let y = objective(&x.view());
426                self.record_observation(x.clone(), y);
427            }
428
429            self.fit_surrogate()?;
430
431            if let Some(best_idx) = self.best_idx {
432                best_history.push(self.observations[best_idx].y);
433            }
434        }
435
436        let best_idx = self.best_idx.ok_or_else(|| {
437            OptimizeError::ComputationError("No observations collected".to_string())
438        })?;
439        let best_obs = &self.observations[best_idx];
440
441        Ok(BayesianOptResult {
442            x_best: best_obs.x.clone(),
443            f_best: best_obs.y,
444            observations: self.observations.clone(),
445            n_evals: self.observations.len(),
446            best_history,
447            success: true,
448            message: format!(
449                "Batch optimization completed: {} evaluations, best f = {:.6e}",
450                self.observations.len(),
451                best_obs.y
452            ),
453        })
454    }
455
456    /// Multi-objective optimization via ParEGO scalarization.
457    ///
458    /// Uses random weight vectors to scalarise the objectives into a single
459    /// augmented Chebyshev function, then runs standard BO on the scalarization.
460    ///
461    /// # Arguments
462    /// * `objectives` - Vector of objective functions to minimize.
463    /// * `n_iter` - Number of sequential iterations.
464    pub fn optimize_multi_objective<F>(
465        &mut self,
466        objectives: &[F],
467        n_iter: usize,
468    ) -> OptimizeResult<BayesianOptResult>
469    where
470        F: Fn(&ArrayView1<f64>) -> f64,
471    {
472        if objectives.is_empty() {
473            return Err(OptimizeError::InvalidInput(
474                "At least one objective is required".to_string(),
475            ));
476        }
477        if objectives.len() == 1 {
478            // Single objective: delegate to standard optimize
479            return self.optimize(&objectives[0], n_iter);
480        }
481
482        let n_obj = objectives.len();
483
484        // Phase 1: Initial design
485        let n_initial = if self.observations.is_empty() {
486            self.config.n_initial
487        } else {
488            self.config
489                .n_initial
490                .saturating_sub(self.observations.len())
491        };
492
493        // Store all objective values for normalization
494        let mut all_obj_values: Vec<Vec<f64>> = vec![Vec::new(); n_obj];
495
496        if n_initial > 0 {
497            let sampling_config = SamplingConfig {
498                seed: Some(self.rng.random()),
499                ..Default::default()
500            };
501            let initial_points = generate_samples(
502                n_initial,
503                &self.bounds,
504                self.config.initial_design,
505                Some(sampling_config),
506            )?;
507
508            for i in 0..initial_points.nrows() {
509                let mut x = initial_points.row(i).to_owned();
510                self.enforce_dim_types(&mut x);
511                let obj_vals: Vec<f64> = objectives.iter().map(|f| f(&x.view())).collect();
512
513                // ParEGO scalarization with uniform weight (initial)
514                let scalarized = parego_scalarize(&obj_vals, &vec![1.0 / n_obj as f64; n_obj]);
515                self.record_observation(x, scalarized);
516
517                for (k, &v) in obj_vals.iter().enumerate() {
518                    all_obj_values[k].push(v);
519                }
520            }
521            self.fit_surrogate()?;
522        }
523
524        let mut best_history = Vec::new();
525        if let Some(best_idx) = self.best_idx {
526            best_history.push(self.observations[best_idx].y);
527        }
528
529        // Phase 2: Sequential iterations with rotating random weights
530        for _iter in 0..n_iter {
531            // Generate random weight vector on the simplex
532            let weights = random_simplex_point(n_obj, &mut self.rng);
533
534            // Suggest next point (based on current scalarized GP)
535            let next_x = self.suggest_next()?;
536
537            // Evaluate all objectives
538            let obj_vals: Vec<f64> = objectives.iter().map(|f| f(&next_x.view())).collect();
539            for (k, &v) in obj_vals.iter().enumerate() {
540                all_obj_values[k].push(v);
541            }
542
543            // Normalize and scalarize
544            let normalized: Vec<f64> = (0..n_obj)
545                .map(|k| {
546                    let vals = &all_obj_values[k];
547                    let min_v = vals.iter().copied().fold(f64::INFINITY, f64::min);
548                    let max_v = vals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
549                    let range = (max_v - min_v).max(1e-12);
550                    (obj_vals[k] - min_v) / range
551                })
552                .collect();
553
554            let scalarized = parego_scalarize(&normalized, &weights);
555            self.record_observation(next_x, scalarized);
556            self.fit_surrogate()?;
557
558            if let Some(best_idx) = self.best_idx {
559                best_history.push(self.observations[best_idx].y);
560            }
561        }
562
563        let best_idx = self.best_idx.ok_or_else(|| {
564            OptimizeError::ComputationError("No observations collected".to_string())
565        })?;
566        let best_obs = &self.observations[best_idx];
567
568        Ok(BayesianOptResult {
569            x_best: best_obs.x.clone(),
570            f_best: best_obs.y,
571            observations: self.observations.clone(),
572            n_evals: self.observations.len(),
573            best_history,
574            success: true,
575            message: format!(
576                "ParEGO multi-objective optimization completed: {} evaluations",
577                self.observations.len()
578            ),
579        })
580    }
581
582    /// Get the ask interface: suggest the next point to evaluate.
583    pub fn ask(&mut self) -> OptimizeResult<Array1<f64>> {
584        if self.observations.is_empty() || self.observations.len() < self.config.n_initial {
585            // Still in initial design phase
586            let sampling_config = SamplingConfig {
587                seed: Some(self.rng.random()),
588                ..Default::default()
589            };
590            let points = generate_samples(
591                1,
592                &self.bounds,
593                self.config.initial_design,
594                Some(sampling_config),
595            )?;
596            Ok(points.row(0).to_owned())
597        } else {
598            self.suggest_next()
599        }
600    }
601
602    /// Tell interface: update with an observation.
603    pub fn tell(&mut self, x: Array1<f64>, y: f64) -> OptimizeResult<()> {
604        self.record_observation(x, y);
605        if self.observations.len() >= 2 {
606            self.fit_surrogate()?;
607        }
608        Ok(())
609    }
610
611    /// Get the current best observation.
612    pub fn best(&self) -> Option<&Observation> {
613        self.best_idx.map(|i| &self.observations[i])
614    }
615
616    /// Get all observations.
617    pub fn observations(&self) -> &[Observation] {
618        &self.observations
619    }
620
621    /// Number of observations.
622    pub fn n_observations(&self) -> usize {
623        self.observations.len()
624    }
625
626    /// Get reference to the GP surrogate.
627    pub fn surrogate(&self) -> &GpSurrogate {
628        &self.surrogate
629    }
630
631    // -----------------------------------------------------------------------
632    // Internal methods
633    // -----------------------------------------------------------------------
634
635    /// Record an observation and update the best index.
636    fn record_observation(&mut self, x: Array1<f64>, y: f64) {
637        let feasible = self.evaluate_constraints(&x);
638
639        let obs = Observation {
640            x,
641            y,
642            constraints: Vec::new(), // filled below if needed
643            feasible,
644        };
645
646        let idx = self.observations.len();
647
648        // Update best (prefer feasible solutions)
649        match self.best_idx {
650            Some(best) => {
651                let cur_best = &self.observations[best];
652                let new_is_better = if obs.feasible && !cur_best.feasible {
653                    true
654                } else if obs.feasible == cur_best.feasible {
655                    obs.y < cur_best.y
656                } else {
657                    false
658                };
659                if new_is_better {
660                    self.best_idx = Some(idx);
661                }
662            }
663            None => {
664                self.best_idx = Some(idx);
665            }
666        }
667
668        self.observations.push(obs);
669    }
670
671    /// Evaluate constraints for a point; returns true if all constraints are satisfied.
672    fn evaluate_constraints(&self, x: &Array1<f64>) -> bool {
673        self.constraints.iter().all(|c| (c.func)(&x.view()) <= 0.0)
674    }
675
676    /// Fit or refit the GP surrogate on all observations.
677    fn fit_surrogate(&mut self) -> OptimizeResult<()> {
678        let n = self.observations.len();
679        if n == 0 {
680            return Ok(());
681        }
682        let n_dims = self.observations[0].x.len();
683
684        let mut x_data = Array2::zeros((n, n_dims));
685        let mut y_data = Array1::zeros(n);
686
687        for (i, obs) in self.observations.iter().enumerate() {
688            for j in 0..n_dims {
689                x_data[[i, j]] = obs.x[j];
690            }
691            y_data[i] = obs.y;
692        }
693
694        self.surrogate.fit(&x_data, &y_data)
695    }
696
697    /// Suggest the next point to evaluate by optimising the acquisition function.
698    fn suggest_next(&mut self) -> OptimizeResult<Array1<f64>> {
699        let f_best = self.best_idx.map(|i| self.observations[i].y).unwrap_or(0.0);
700
701        // Build reference points for KG if needed
702        let n = self.observations.len();
703        let n_dims = self.bounds.len();
704        let ref_points = if n > 0 {
705            let mut pts = Array2::zeros((n, n_dims));
706            for (i, obs) in self.observations.iter().enumerate() {
707                for j in 0..n_dims {
708                    pts[[i, j]] = obs.x[j];
709                }
710            }
711            Some(pts)
712        } else {
713            None
714        };
715
716        let acq = self.config.acquisition.build(f_best, ref_points.as_ref());
717
718        self.optimize_acquisition(acq.as_ref())
719    }
720
721    /// Suggest a batch of points using the Kriging Believer strategy.
722    fn suggest_batch(&mut self, batch_size: usize) -> OptimizeResult<Vec<Array1<f64>>> {
723        let mut batch = Vec::with_capacity(batch_size);
724
725        for _ in 0..batch_size {
726            let next = self.suggest_next()?;
727
728            // Fantasy: predict mean at the selected point and add it as a phantom observation
729            let (mu, _sigma) = self.surrogate.predict_single(&next.view())?;
730            self.record_observation(next.clone(), mu);
731            self.fit_surrogate()?;
732
733            batch.push(next);
734        }
735
736        // Remove the phantom observations (they will be replaced with real ones)
737        let n_real = self.observations.len() - batch_size;
738        self.observations.truncate(n_real);
739
740        // Refit surrogate without phantoms
741        if !self.observations.is_empty() {
742            // Update best_idx in case we removed the best
743            self.best_idx = None;
744            for (i, obs) in self.observations.iter().enumerate() {
745                match self.best_idx {
746                    Some(best) if obs.y < self.observations[best].y => {
747                        self.best_idx = Some(i);
748                    }
749                    None => {
750                        self.best_idx = Some(i);
751                    }
752                    _ => {}
753                }
754            }
755            self.fit_surrogate()?;
756        }
757
758        Ok(batch)
759    }
760
761    /// Optimise the acquisition function over the search space.
762    ///
763    /// Uses random sampling + local refinement (coordinate search).
764    fn optimize_acquisition(&mut self, acq: &dyn AcquisitionFn) -> OptimizeResult<Array1<f64>> {
765        let n_dims = self.bounds.len();
766        let n_candidates = self.config.acq_n_candidates;
767        let n_restarts = self.config.acq_n_restarts;
768
769        // Generate random candidates
770        let sampling_config = SamplingConfig {
771            seed: Some(self.rng.random()),
772            ..Default::default()
773        };
774        let candidates = generate_samples(
775            n_candidates,
776            &self.bounds,
777            SamplingStrategy::Random,
778            Some(sampling_config),
779        )?;
780
781        // Also include the current best as a candidate
782        let mut best_x = candidates.row(0).to_owned();
783        let mut best_val = f64::NEG_INFINITY;
784
785        // Evaluate all candidates
786        for i in 0..candidates.nrows() {
787            match acq.evaluate(&candidates.row(i), &self.surrogate) {
788                Ok(val) if val > best_val => {
789                    best_val = val;
790                    best_x = candidates.row(i).to_owned();
791                }
792                _ => {}
793            }
794        }
795
796        // If we have a current best observation, add it as a candidate
797        if let Some(best_idx) = self.best_idx {
798            let obs_x = &self.observations[best_idx].x;
799            if let Ok(val) = acq.evaluate(&obs_x.view(), &self.surrogate) {
800                if val > best_val {
801                    best_val = val;
802                    best_x = obs_x.clone();
803                }
804            }
805        }
806
807        // Local refinement: coordinate-wise search from the top-n candidates
808        // Collect top candidates
809        let mut scored: Vec<(f64, usize)> = Vec::new();
810        for i in 0..candidates.nrows() {
811            if let Ok(val) = acq.evaluate(&candidates.row(i), &self.surrogate) {
812                scored.push((val, i));
813            }
814        }
815        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
816
817        let n_refine = n_restarts.min(scored.len());
818        for k in 0..n_refine {
819            let mut x_current = candidates.row(scored[k].1).to_owned();
820            let mut f_current = scored[k].0;
821
822            // Coordinate-wise golden section search
823            for _round in 0..3 {
824                for d in 0..n_dims {
825                    let (lo, hi) = self.bounds[d];
826                    let (refined_x, refined_f) =
827                        golden_section_1d(acq, &self.surrogate, &x_current, d, lo, hi, 20)?;
828                    if refined_f > f_current {
829                        x_current[d] = refined_x;
830                        f_current = refined_f;
831                    }
832                }
833            }
834
835            if f_current > best_val {
836                best_val = f_current;
837                best_x = x_current;
838            }
839        }
840
841        // Clamp to bounds
842        for (d, &(lo, hi)) in self.bounds.iter().enumerate() {
843            best_x[d] = best_x[d].clamp(lo, hi);
844        }
845
846        // Enforce integer rounding for integer dimensions
847        self.enforce_dim_types(&mut best_x);
848
849        Ok(best_x)
850    }
851}
852
853// ---------------------------------------------------------------------------
854// Helper functions
855// ---------------------------------------------------------------------------
856
857/// Golden section search for maximising `acq(x_base with dim d = t)` over [lo, hi].
858fn golden_section_1d(
859    acq: &dyn AcquisitionFn,
860    surrogate: &GpSurrogate,
861    x_base: &Array1<f64>,
862    dim: usize,
863    lo: f64,
864    hi: f64,
865    max_iters: usize,
866) -> OptimizeResult<(f64, f64)> {
867    let gr = (5.0_f64.sqrt() - 1.0) / 2.0; // golden ratio conjugate
868    let mut a = lo;
869    let mut b = hi;
870
871    let eval_at = |t: f64| -> OptimizeResult<f64> {
872        let mut x = x_base.clone();
873        x[dim] = t;
874        acq.evaluate(&x.view(), surrogate)
875    };
876
877    let mut c = b - gr * (b - a);
878    let mut d = a + gr * (b - a);
879    let mut fc = eval_at(c)?;
880    let mut fd = eval_at(d)?;
881
882    for _ in 0..max_iters {
883        if (b - a).abs() < 1e-8 {
884            break;
885        }
886        // We want to maximise, so we keep the side with the larger value
887        if fc < fd {
888            a = c;
889            c = d;
890            fc = fd;
891            d = a + gr * (b - a);
892            fd = eval_at(d)?;
893        } else {
894            b = d;
895            d = c;
896            fd = fc;
897            c = b - gr * (b - a);
898            fc = eval_at(c)?;
899        }
900    }
901
902    let mid = (a + b) / 2.0;
903    let f_mid = eval_at(mid)?;
904    Ok((mid, f_mid))
905}
906
/// ParEGO augmented Chebyshev scalarization.
///
/// s(f, w) = max_k { w_k * f_k } + rho * sum_k { w_k * f_k }
///
/// where `rho = 0.05` is a small augmentation coefficient.
///
/// Objectives and weights are paired element-wise. If the slices differ in
/// length, the extra entries of the longer one are ignored. With no pairs at
/// all, the result is `f64::NEG_INFINITY` (the max over an empty set).
fn parego_scalarize(obj_values: &[f64], weights: &[f64]) -> f64 {
    const RHO: f64 = 0.05;

    let (max_wf, sum_wf) = obj_values
        .iter()
        .zip(weights)
        .map(|(&fk, &wk)| wk * fk)
        .fold((f64::NEG_INFINITY, 0.0_f64), |(max_wf, sum_wf), wf| {
            (max_wf.max(wf), sum_wf + wf)
        });

    max_wf + RHO * sum_wf
}
927
928/// Generate a random point on the probability simplex using the Dirichlet trick.
929fn random_simplex_point(n: usize, rng: &mut StdRng) -> Vec<f64> {
930    if n == 0 {
931        return Vec::new();
932    }
933    if n == 1 {
934        return vec![1.0];
935    }
936
937    // Sample from Exp(1) and normalize
938    let mut values: Vec<f64> = (0..n)
939        .map(|_| {
940            let u: f64 = rng.random_range(1e-10..1.0);
941            -u.ln()
942        })
943        .collect();
944
945    let sum: f64 = values.iter().sum();
946    if sum > 0.0 {
947        for v in &mut values {
948            *v /= sum;
949        }
950    } else {
951        // Fallback to uniform
952        let w = 1.0 / n as f64;
953        values.fill(w);
954    }
955    values
956}
957
958// ---------------------------------------------------------------------------
959// Convenience function
960// ---------------------------------------------------------------------------
961
962/// Run Bayesian optimization on a function.
963///
964/// This is a high-level convenience function that creates a `BayesianOptimizer`,
965/// runs the optimization, and returns the result.
966///
967/// # Arguments
968/// * `objective` - Function to minimize: `f(x) -> f64`
969/// * `bounds` - Search bounds: `[(lo, hi), ...]`
970/// * `n_iter` - Number of sequential iterations (after initial design)
971/// * `config` - Optional optimizer configuration
972///
973/// # Example
974///
975/// ```rust
976/// use scirs2_optimize::bayesian::optimize;
977/// use scirs2_core::ndarray::ArrayView1;
978///
979/// let result = optimize(
980///     |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2),
981///     &[(-5.0, 5.0), (-5.0, 5.0)],
982///     20,
983///     None,
984/// ).expect("optimization failed");
985///
986/// assert!(result.f_best < 1.0);
987/// ```
988pub fn optimize<F>(
989    objective: F,
990    bounds: &[(f64, f64)],
991    n_iter: usize,
992    config: Option<BayesianOptimizerConfig>,
993) -> OptimizeResult<BayesianOptResult>
994where
995    F: Fn(&ArrayView1<f64>) -> f64,
996{
997    let config = config.unwrap_or_default();
998    let mut optimizer = BayesianOptimizer::new(bounds.to_vec(), config)?;
999    optimizer.optimize(objective, n_iter)
1000}
1001
1002// ---------------------------------------------------------------------------
1003// Tests
1004// ---------------------------------------------------------------------------
1005
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Sum of squares; global minimum 0 at the origin.
    fn sphere(x: &ArrayView1<f64>) -> f64 {
        x.iter().map(|&v| v * v).sum()
    }

    /// Reproducible configuration shared by most tests.
    ///
    /// It uses a fixed seed, fixed GP hyperparameters (no hyperparameter
    /// optimization), and a small noise term. Only the initial-design size
    /// varies between tests.
    fn seeded_config(n_initial: usize) -> BayesianOptimizerConfig {
        BayesianOptimizerConfig {
            n_initial,
            seed: Some(42),
            gp_config: GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    #[test]
    fn test_optimize_sphere_2d() {
        let result = optimize(
            sphere,
            &[(-5.0, 5.0), (-5.0, 5.0)],
            25,
            Some(seeded_config(8)),
        )
        .expect("optimization should succeed");

        assert!(result.success);
        assert!(result.f_best < 2.0, "f_best = {:.4}", result.f_best);
    }

    #[test]
    fn test_optimizer_ask_tell() {
        let mut opt = BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], seeded_config(5))
            .expect("create ok");

        for _ in 0..15 {
            let x = opt.ask().expect("ask ok");
            let y = sphere(&x.view());
            opt.tell(x, y).expect("tell ok");
        }

        let best = opt.best().expect("should have a best");
        assert!(best.y < 5.0, "best y = {:.4}", best.y);
    }

    #[test]
    fn test_warm_start() {
        let mut opt = BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], seeded_config(3))
            .expect("create ok");

        // Warm start with some previous data
        let x_prev =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.2, -0.3, 0.1, 0.5, -0.5]).expect("shape ok");
        let y_prev = array![0.05, 0.1, 0.5];
        opt.warm_start(&x_prev, &y_prev).expect("warm start ok");

        assert_eq!(opt.n_observations(), 3);

        let result = opt.optimize(sphere, 10).expect("optimize ok");
        assert!(result.f_best < 0.5);
    }

    #[test]
    fn test_batch_optimization() {
        let mut opt = BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], seeded_config(5))
            .expect("create ok");

        let result = opt
            .optimize_batch(sphere, 5, 3)
            .expect("batch optimization ok");
        assert!(result.success);
        // 5 initial + 5*3 = 20 total evaluations
        assert_eq!(result.n_evals, 20);
    }

    #[test]
    fn test_constrained_optimization() {
        let mut opt = BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], seeded_config(8))
            .expect("create ok");

        // Constraint: x[0] >= 1.0 (i.e., 1.0 - x[0] <= 0)
        opt.add_constraint("x0_ge_1", |x: &ArrayView1<f64>| 1.0 - x[0]);

        let result = opt.optimize(sphere, 20).expect("optimize ok");
        // The constrained minimum of x^2+y^2 with x >= 1 is at (1,0), f=1
        // We just check the optimizer found something feasible and reasonable
        assert!(result.success);
        assert!(
            result.x_best[0] >= 0.5,
            "x[0] should be near >= 1, got {}",
            result.x_best[0]
        );
    }

    #[test]
    fn test_multi_objective_parego() {
        let mut opt = BayesianOptimizer::new(vec![(-5.0, 5.0), (-5.0, 5.0)], seeded_config(8))
            .expect("create ok");

        // Two objectives: f1 = (x-1)^2 + y^2, f2 = (x+1)^2 + y^2
        let f1 = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + x[1].powi(2);
        let f2 = |x: &ArrayView1<f64>| (x[0] + 1.0).powi(2) + x[1].powi(2);
        let objectives: Vec<Box<dyn Fn(&ArrayView1<f64>) -> f64>> =
            vec![Box::new(f1), Box::new(f2)];

        // `optimize_multi_objective` takes a slice of trait-object references.
        let obj_refs: Vec<&dyn Fn(&ArrayView1<f64>) -> f64> = objectives
            .iter()
            .map(|f| f.as_ref() as &dyn Fn(&ArrayView1<f64>) -> f64)
            .collect();

        let result = opt
            .optimize_multi_objective(&obj_refs[..], 15)
            .expect("multi-objective ok");
        assert!(result.success);
        // The Pareto front is between x=-1 and x=1; this only checks that the
        // reported point stays inside the search box.
        assert!(result.x_best[0].abs() <= 5.0);
    }

    #[test]
    fn test_different_acquisition_functions() {
        let bounds = vec![(-3.0, 3.0)];

        for acq in &[
            AcquisitionType::EI { xi: 0.01 },
            AcquisitionType::PI { xi: 0.01 },
            AcquisitionType::UCB { kappa: 2.0 },
            AcquisitionType::Thompson { seed: 42 },
        ] {
            let config = BayesianOptimizerConfig {
                acquisition: acq.clone(),
                ..seeded_config(5)
            };
            let result = optimize(
                |x: &ArrayView1<f64>| x[0].powi(2),
                &bounds,
                10,
                Some(config),
            )
            .expect("optimize ok");
            assert!(
                result.f_best < 3.0,
                "Acquisition {:?} failed: f_best = {}",
                acq,
                result.f_best
            );
        }
    }

    #[test]
    fn test_invalid_bounds_rejected() {
        let result = BayesianOptimizer::new(
            vec![(5.0, 1.0)], // lo > hi
            BayesianOptimizerConfig::default(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_bounds_rejected() {
        let result = BayesianOptimizer::new(vec![], BayesianOptimizerConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_best_history_monotonic() {
        let result = optimize(sphere, &[(-5.0, 5.0), (-5.0, 5.0)], 10, Some(seeded_config(5)))
            .expect("optimize ok");

        // Best history should be non-increasing
        for i in 1..result.best_history.len() {
            assert!(
                result.best_history[i] <= result.best_history[i - 1] + 1e-12,
                "Best history not monotonic at index {}: {} > {}",
                i,
                result.best_history[i],
                result.best_history[i - 1]
            );
        }
    }

    #[test]
    fn test_parego_scalarize() {
        let obj = [0.3, 0.7];
        let w = [0.5, 0.5];
        let s = parego_scalarize(&obj, &w);
        // max(0.15, 0.35) + 0.05 * (0.15 + 0.35) = 0.35 + 0.025 = 0.375
        assert!((s - 0.375).abs() < 1e-10);
    }

    #[test]
    fn test_random_simplex_point_sums_to_one() {
        let mut rng = StdRng::seed_from_u64(42);
        for n in 1..6 {
            let pt = random_simplex_point(n, &mut rng);
            assert_eq!(pt.len(), n);
            let sum: f64 = pt.iter().sum();
            assert!((sum - 1.0).abs() < 1e-10, "Simplex sum = {}", sum);
            for &v in &pt {
                assert!(v >= 0.0, "Simplex component negative: {}", v);
            }
        }
    }

    #[test]
    fn test_optimize_1d() {
        let result = optimize(
            |x: &ArrayView1<f64>| (x[0] - 2.0).powi(2),
            &[(-5.0, 5.0)],
            15,
            Some(seeded_config(5)),
        )
        .expect("optimize ok");

        assert!(
            (result.x_best[0] - 2.0).abs() < 1.5,
            "x_best = {:.4}, expected ~2.0",
            result.x_best[0]
        );
        assert!(result.f_best < 2.0);
    }

    #[test]
    fn test_integer_dimension_enforcement() {
        // Define 1D integer search space: x in {0, 1, 2, 3}
        let bounds = vec![(0.0, 3.0)];
        let config = BayesianOptimizerConfig {
            n_initial: 4,
            seed: Some(42),
            acq_n_candidates: 50,
            ..Default::default()
        };
        let mut opt = BayesianOptimizer::new(bounds, config).expect("Failed to create optimizer");
        opt.set_dimension_types(vec![DimensionType::Integer])
            .expect("Failed to set dim types");

        // Objective: f(x) = (x - 2)^2, minimum at x=2
        let result = opt
            .optimize(
                |x| {
                    let v = x[0];
                    (v - 2.0).powi(2)
                },
                6,
            )
            .expect("Optimization failed");

        // All evaluated points must be integers in [0, 3]
        for obs in &result.observations {
            let v = obs.x[0];
            assert!((0.0..=3.0).contains(&v), "Out of bounds: {}", v);
            assert!((v - v.round()).abs() < 1e-12, "Not integer: {}", v);
        }

        // Best should be x=2
        assert!(
            (result.x_best[0] - 2.0).abs() < 1e-12,
            "Best x should be 2, got {}",
            result.x_best[0]
        );
    }

    #[test]
    fn test_set_dimension_types_length_mismatch() {
        let bounds = vec![(0.0, 5.0), (0.0, 5.0)];
        let config = BayesianOptimizerConfig::default();
        let mut opt = BayesianOptimizer::new(bounds, config).expect("Failed to create optimizer");
        // Wrong length should error
        let result = opt.set_dimension_types(vec![DimensionType::Integer]);
        assert!(result.is_err());
    }
}