Skip to main content

scirs2_optimize/bayesian/
warm_start.rs

1//! Warm-starting and transfer learning for Bayesian Optimization.
2//!
3//! Provides mechanisms to seed a new Bayesian optimization run with knowledge
4//! from previous runs, related tasks, or meta-learned initialization strategies.
5//!
6//! # Strategies
7//!
8//! 1. **Direct warm-start**: inject prior observations directly into the GP.
//! 2. **Scaled transfer**: align source observations to the target domain via
//!    min-max rescaling of their inputs, normalize their objective values, and
//!    inject them.
11//! 3. **Multi-task BO**: maintain separate GPs per task and combine acquisition
12//!    values using task-similarity weights.
13//! 4. **Meta-learning**: estimate good GP hyperparameter initialization from
14//!    observed task features (warm-starting the surrogate model itself).
15//!
16//! # Example
17//!
18//! ```rust
19//! use scirs2_optimize::bayesian::warm_start::{
20//!     WarmStartBo, WarmStartConfig, PriorRun, MetaLearner,
21//! };
22//! use scirs2_core::ndarray::{Array1, Array2};
23//!
24//! // A previous run on a related problem:
25//! let prior = PriorRun {
26//!     x: Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).expect("shape"),
27//!     y: Array1::from_vec(vec![1.0, 0.5, 0.1, 0.6]),
28//!     bounds: vec![(0.0_f64, 4.0_f64)],
29//!     weight: 0.8,
30//! };
31//!
32//! let config = WarmStartConfig {
33//!     prior_runs: vec![prior],
34//!     n_initial: 3,
35//!     seed: Some(42),
36//!     ..Default::default()
37//! };
38//!
39//! let mut bo = WarmStartBo::new(vec![(0.0_f64, 4.0_f64)], config).expect("create");
40//! let result = bo.optimize(|x: &[f64]| (x[0] - 1.5_f64).powi(2), 10).expect("opt");
41//! println!("Best x: {:?}  f: {:.4}", result.x_best, result.f_best);
42//! ```
43
44use scirs2_core::ndarray::{Array1, Array2};
45use scirs2_core::random::rngs::StdRng;
46use scirs2_core::random::{Rng, RngExt, SeedableRng};
47
48use crate::error::{OptimizeError, OptimizeResult};
49
50use super::acquisition::{AcquisitionFn, AcquisitionType, ExpectedImprovement};
51use super::gp::{GpSurrogate, GpSurrogateConfig, RbfKernel};
52use super::sampling::{generate_samples, SamplingStrategy};
53
54// ---------------------------------------------------------------------------
55// Data structures
56// ---------------------------------------------------------------------------
57
58/// A record of a previous optimization run that can be used to warm-start
59/// a new run.
60#[derive(Debug, Clone)]
61pub struct PriorRun {
62    /// Input matrix from the prior run (n_obs × n_dims).
63    pub x: Array2<f64>,
64    /// Observed objective values (n_obs,).
65    pub y: Array1<f64>,
66    /// Bounds of the prior run's search space.
67    pub bounds: Vec<(f64, f64)>,
68    /// Relative weight in [0, 1] for blending with the target task.
69    /// 1.0 = full trust, 0.0 = completely ignore.
70    pub weight: f64,
71}
72
/// How observations from prior runs are folded into the current run.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BlendStrategy {
    /// Inject prior observations with an inflated noise variance so they carry less weight.
    NoisyInjection {
        /// Factor by which the base noise variance is multiplied for prior points.
        noise_multiplier: f64,
    },
    /// Rescale prior observations to the current run's range, then inject them.
    RescaleAndInject,
    /// Use prior runs only to initialise GP hyperparameters; no data is injected.
    HyperparamOnly,
    /// Blend acquisition values from independent per-run GPs using the run weights.
    WeightedEnsemble,
}

impl BlendStrategy {
    /// Noise multiplier used by the default strategy.
    const DEFAULT_NOISE_MULTIPLIER: f64 = 10.0;
}

impl Default for BlendStrategy {
    /// Defaults to [`BlendStrategy::NoisyInjection`] with a noise multiplier of `10.0`.
    fn default() -> Self {
        BlendStrategy::NoisyInjection {
            noise_multiplier: Self::DEFAULT_NOISE_MULTIPLIER,
        }
    }
}
96
97/// Configuration for warm-start Bayesian optimization.
98#[derive(Clone)]
99pub struct WarmStartConfig {
100    /// Prior runs to use for warm-starting.
101    pub prior_runs: Vec<PriorRun>,
102    /// Strategy for blending prior data.
103    pub blend_strategy: BlendStrategy,
104    /// Number of initial random points to evaluate on the target task before
105    /// switching to BO (in addition to injected prior data).
106    pub n_initial: usize,
107    /// Acquisition function to use.
108    pub acquisition: AcquisitionType,
109    /// Seed for reproducibility.
110    pub seed: Option<u64>,
111    /// Number of candidates evaluated per acquisition optimization step.
112    pub acq_n_candidates: usize,
113    /// Verbose output level (0 = silent).
114    pub verbose: usize,
115}
116
117impl Default for WarmStartConfig {
118    fn default() -> Self {
119        Self {
120            prior_runs: Vec::new(),
121            blend_strategy: BlendStrategy::default(),
122            n_initial: 5,
123            acquisition: AcquisitionType::EI { xi: 0.01 },
124            seed: None,
125            acq_n_candidates: 200,
126            verbose: 0,
127        }
128    }
129}
130
131/// A single observation recorded during optimization.
132#[derive(Debug, Clone)]
133pub struct WarmStartObs {
134    pub x: Array1<f64>,
135    pub y: f64,
136}
137
138/// Result of a warm-start Bayesian optimization run.
139#[derive(Debug, Clone)]
140pub struct WarmStartResult {
141    /// Best input point found on the *target* task.
142    pub x_best: Array1<f64>,
143    /// Best objective value found on the target task.
144    pub f_best: f64,
145    /// All target-task observations in evaluation order.
146    pub observations: Vec<WarmStartObs>,
147    /// Number of target-task function evaluations.
148    pub n_evals: usize,
149    /// Best-value history across iterations.
150    pub best_history: Vec<f64>,
151}
152
153// ---------------------------------------------------------------------------
154// Meta-learner: warm-start GP hyperparameters
155// ---------------------------------------------------------------------------
156
157/// Features extracted from a prior run for meta-learning.
158#[derive(Debug, Clone)]
159struct TaskFeatures {
160    /// Mean of observed y-values.
161    y_mean: f64,
162    /// Std-dev of observed y-values.
163    y_std: f64,
164    /// Median pairwise input distance (proxy for scale).
165    median_dist: f64,
166    /// Optimum-to-range ratio (how well-conditioned the optimum is).
167    opt_ratio: f64,
168}
169
170impl TaskFeatures {
171    fn from_run(run: &PriorRun) -> Self {
172        let n = run.y.len();
173        let y_mean = run.y.iter().copied().sum::<f64>() / n.max(1) as f64;
174        let y_var = run.y.iter().map(|&v| (v - y_mean).powi(2)).sum::<f64>() / n.max(1) as f64;
175        let y_std = y_var.sqrt().max(1e-10);
176        let y_min = run.y.iter().copied().fold(f64::INFINITY, f64::min);
177        let y_max = run.y.iter().copied().fold(f64::NEG_INFINITY, f64::max);
178        let y_range = (y_max - y_min).max(1e-10);
179        let opt_ratio = (y_mean - y_min) / y_range;
180
181        // Compute a subset of pairwise distances for efficiency.
182        let mut dists = Vec::new();
183        let n_sub = n.min(20);
184        for i in 0..n_sub {
185            for j in (i + 1)..n_sub {
186                let row_i = run.x.row(i);
187                let row_j = run.x.row(j);
188                let sq_d: f64 = row_i
189                    .iter()
190                    .zip(row_j.iter())
191                    .map(|(a, b)| (a - b).powi(2))
192                    .sum();
193                dists.push(sq_d.sqrt());
194            }
195        }
196        dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
197        let median_dist = if dists.is_empty() {
198            1.0
199        } else {
200            dists[dists.len() / 2]
201        };
202
203        Self {
204            y_mean,
205            y_std,
206            median_dist: median_dist.max(1e-10),
207            opt_ratio,
208        }
209    }
210}
211
212/// Meta-learner that estimates good initial GP hyperparameters from prior tasks.
213///
214/// Uses a simple similarity-weighted average of per-task features.
215#[derive(Debug, Clone)]
216pub struct MetaLearner {
217    task_features: Vec<TaskFeatures>,
218    task_weights: Vec<f64>,
219}
220
221impl MetaLearner {
222    /// Create a meta-learner from a set of prior runs.
223    pub fn from_runs(runs: &[PriorRun]) -> Self {
224        let task_features: Vec<_> = runs.iter().map(TaskFeatures::from_run).collect();
225        let task_weights: Vec<_> = runs.iter().map(|r| r.weight.max(0.0)).collect();
226        Self {
227            task_features,
228            task_weights,
229        }
230    }
231
232    /// Suggest initial GP hyperparameters for a target task.
233    ///
234    /// Returns `(length_scale, signal_variance, noise_variance)`.
235    pub fn suggest_hyperparams(&self, target_bounds: &[(f64, f64)]) -> (f64, f64, f64) {
236        if self.task_features.is_empty() {
237            return (1.0, 1.0, 1e-4);
238        }
239
240        let total_weight: f64 = self.task_weights.iter().sum::<f64>().max(1e-10);
241
242        // Suggest length scale as fraction of the target domain diameter.
243        let domain_diameter: f64 = target_bounds
244            .iter()
245            .map(|(lo, hi)| (hi - lo).powi(2))
246            .sum::<f64>()
247            .sqrt()
248            .max(1e-10);
249
250        let mut weighted_ls = 0.0_f64;
251        let mut weighted_sv = 0.0_f64;
252
253        for (feat, &w) in self.task_features.iter().zip(self.task_weights.iter()) {
254            // Rescale the task's length scale to the target domain.
255            let rel_ls = feat.median_dist / domain_diameter;
256            weighted_ls += w * rel_ls;
257            weighted_sv += w * feat.y_std * feat.y_std;
258        }
259
260        let ls = (weighted_ls / total_weight * domain_diameter).max(1e-3);
261        let sv = (weighted_sv / total_weight).max(1e-6);
262        let noise_var = sv * 1e-3;
263
264        (ls, sv, noise_var)
265    }
266
267    /// Compute similarity between a prior task and a new target task.
268    pub fn task_similarity(prior: &PriorRun, target_bounds: &[(f64, f64)]) -> f64 {
269        if prior.x.is_empty() {
270            return 0.0;
271        }
272        let ndim = target_bounds.len().min(prior.x.ncols());
273        let n = prior.x.nrows();
274        let mut score = 0.0_f64;
275        for i in 0..n {
276            let row = prior.x.row(i);
277            let mut in_bounds = true;
278            let mut centrality = 0.0_f64;
279            for d in 0..ndim {
280                let (lo, hi) = target_bounds[d];
281                let range = (hi - lo).max(1e-10);
282                let v = row[d];
283                if v < lo || v > hi {
284                    in_bounds = false;
285                    break;
286                }
287                // Centrality: 1.0 at center, 0.0 at boundary.
288                let rel = (v - lo) / range;
289                centrality += 1.0 - (2.0 * rel - 1.0).abs();
290            }
291            if in_bounds {
292                score += centrality / ndim as f64;
293            }
294        }
295        (score / n as f64).min(1.0)
296    }
297}
298
299// ---------------------------------------------------------------------------
300// Warm-start BO
301// ---------------------------------------------------------------------------
302
303/// Bayesian optimizer with warm-starting from prior runs.
304pub struct WarmStartBo {
305    bounds: Vec<(f64, f64)>,
306    config: WarmStartConfig,
307    surrogate: GpSurrogate,
308    observations: Vec<WarmStartObs>,
309    rng: StdRng,
310    f_best: f64,
311    best_history: Vec<f64>,
312    /// Base noise variance used by the primary surrogate.
313    base_noise_variance: f64,
314    /// Per-task surrogate GPs for WeightedEnsemble mode.
315    ensemble_surrogates: Vec<(GpSurrogate, f64)>,
316}
317
318impl WarmStartBo {
319    /// Create a new warm-start BO instance.
320    pub fn new(bounds: Vec<(f64, f64)>, config: WarmStartConfig) -> OptimizeResult<Self> {
321        if bounds.is_empty() {
322            return Err(OptimizeError::InvalidInput(
323                "bounds must not be empty".into(),
324            ));
325        }
326
327        let seed = config.seed.unwrap_or(0);
328        let rng = StdRng::seed_from_u64(seed);
329
330        // Optionally warm-start hyperparameters via meta-learning.
331        let meta = MetaLearner::from_runs(&config.prior_runs);
332        let (init_ls, init_sv, init_noise) = meta.suggest_hyperparams(&bounds);
333
334        let gp_config = GpSurrogateConfig {
335            noise_variance: init_noise,
336            optimize_hyperparams: true,
337            ..Default::default()
338        };
339
340        let mut kernel = RbfKernel::new(init_sv, init_ls);
341        kernel.length_scale = init_ls;
342        kernel.signal_variance = init_sv;
343
344        let surrogate = GpSurrogate::new(Box::new(kernel), gp_config);
345
346        // Build ensemble surrogates for each prior run.
347        let ensemble_surrogates =
348            if matches!(config.blend_strategy, BlendStrategy::WeightedEnsemble) {
349                let mut ensemble = Vec::new();
350                for run in &config.prior_runs {
351                    if run.x.nrows() < 2 {
352                        continue;
353                    }
354                    let mut gp = GpSurrogate::new(
355                        Box::new(RbfKernel::default()),
356                        GpSurrogateConfig {
357                            noise_variance: 1e-4,
358                            optimize_hyperparams: false,
359                            ..Default::default()
360                        },
361                    );
362                    if gp.fit(&run.x, &run.y).is_ok() {
363                        ensemble.push((gp, run.weight));
364                    }
365                }
366                ensemble
367            } else {
368                Vec::new()
369            };
370
371        Ok(Self {
372            bounds,
373            config,
374            surrogate,
375            observations: Vec::new(),
376            rng,
377            f_best: f64::INFINITY,
378            best_history: Vec::new(),
379            base_noise_variance: init_noise,
380            ensemble_surrogates,
381        })
382    }
383
384    /// Inject prior observations into the surrogate according to the blend strategy.
385    fn inject_prior_data(&mut self) -> OptimizeResult<()> {
386        match self.config.blend_strategy {
387            BlendStrategy::HyperparamOnly => {
388                // Nothing to inject; hyperparams were already set in `new`.
389                Ok(())
390            }
391            BlendStrategy::WeightedEnsemble => {
392                // Ensemble surrogates are built in `new`; no injection needed.
393                Ok(())
394            }
395            BlendStrategy::NoisyInjection { noise_multiplier } => {
396                let ndim = self.bounds.len();
397                let mut all_x_rows = Vec::new();
398                let mut all_y = Vec::new();
399
400                for run in &self.config.prior_runs {
401                    if run.x.ncols() != ndim || run.x.is_empty() {
402                        continue;
403                    }
404                    let n = run.x.nrows();
405                    for i in 0..n {
406                        let row = run.x.row(i);
407                        // Check that the point is within target bounds.
408                        let in_domain = row
409                            .iter()
410                            .zip(self.bounds.iter())
411                            .all(|(&v, &(lo, hi))| v >= lo && v <= hi);
412                        if in_domain {
413                            all_x_rows.extend(row.iter().copied());
414                            all_y.push(run.y[i]);
415                        }
416                    }
417                }
418
419                if all_x_rows.is_empty() {
420                    return Ok(());
421                }
422
423                let n_prior = all_y.len();
424                let x_prior = Array2::from_shape_vec((n_prior, ndim), all_x_rows)
425                    .map_err(|e| OptimizeError::ComputationError(format!("shape error: {}", e)))?;
426                let y_prior = Array1::from_vec(all_y);
427
428                // Use higher noise for prior points to down-weight them.
429                // We use a separate GP with noisy config for the prior, then rebuild
430                // the main surrogate with normal noise once target data arrives.
431                let prior_noise = self.base_noise_variance * noise_multiplier;
432                let noisy_config = GpSurrogateConfig {
433                    noise_variance: prior_noise,
434                    optimize_hyperparams: false,
435                    ..Default::default()
436                };
437                let new_surrogate =
438                    GpSurrogate::new(self.surrogate.kernel().clone_box(), noisy_config);
439                self.surrogate = new_surrogate;
440                self.surrogate.fit(&x_prior, &y_prior)?;
441
442                // Restore noise level for target data.
443                Ok(())
444            }
445            BlendStrategy::RescaleAndInject => {
446                let ndim = self.bounds.len();
447                let mut all_x_rows = Vec::new();
448                let mut all_y = Vec::new();
449
450                for run in &self.config.prior_runs {
451                    if run.x.ncols() != ndim || run.x.is_empty() {
452                        continue;
453                    }
454                    let n = run.x.nrows();
455
456                    // Compute y range in the prior run.
457                    let y_min = run.y.iter().copied().fold(f64::INFINITY, f64::min);
458                    let y_max = run.y.iter().copied().fold(f64::NEG_INFINITY, f64::max);
459                    let y_range = (y_max - y_min).max(1e-10);
460
461                    for i in 0..n {
462                        let row = run.x.row(i);
463                        // Rescale x from prior bounds to target bounds.
464                        let rescaled: Vec<f64> = row
465                            .iter()
466                            .zip(run.bounds.iter().zip(self.bounds.iter()))
467                            .map(|(&v, (&(s_lo, s_hi), &(t_lo, t_hi)))| {
468                                let s_range = (s_hi - s_lo).max(1e-10);
469                                let t_range = t_hi - t_lo;
470                                t_lo + (v - s_lo) / s_range * t_range
471                            })
472                            .collect();
473
474                        // Clip to target bounds.
475                        let in_domain = rescaled
476                            .iter()
477                            .zip(self.bounds.iter())
478                            .all(|(&v, &(lo, hi))| v >= lo && v <= hi);
479
480                        if in_domain {
481                            all_x_rows.extend(rescaled);
482                            // Rescale y to [0, 1] range (normalized).
483                            let y_rescaled = (run.y[i] - y_min) / y_range;
484                            all_y.push(y_rescaled);
485                        }
486                    }
487                }
488
489                if all_x_rows.is_empty() {
490                    return Ok(());
491                }
492
493                let n_prior = all_y.len();
494                let x_prior = Array2::from_shape_vec((n_prior, ndim), all_x_rows)
495                    .map_err(|e| OptimizeError::ComputationError(format!("shape error: {}", e)))?;
496                let y_prior = Array1::from_vec(all_y);
497
498                self.surrogate.fit(&x_prior, &y_prior)?;
499                Ok(())
500            }
501        }
502    }
503
504    /// Suggest the next point to evaluate.
505    pub fn ask(&mut self) -> OptimizeResult<Vec<f64>> {
506        let ndim = self.bounds.len();
507
508        // If we don't yet have enough target observations, return a random point.
509        if self.observations.len() < self.config.n_initial {
510            let x: Vec<f64> = self
511                .bounds
512                .iter()
513                .map(|&(lo, hi)| lo + self.rng.random::<f64>() * (hi - lo))
514                .collect();
515            return Ok(x);
516        }
517
518        // Optimise the acquisition function via random search.
519        let candidates = generate_samples(
520            self.config.acq_n_candidates,
521            &self.bounds,
522            SamplingStrategy::LatinHypercube,
523            None,
524        )?;
525
526        let acquisition: Box<dyn AcquisitionFn> = self.config.acquisition.build(self.f_best, None);
527
528        let mut best_acq = f64::NEG_INFINITY;
529        let mut best_x = candidates.row(0).to_vec();
530
531        for i in 0..candidates.nrows() {
532            let row = candidates.row(i);
533            let val = if matches!(self.config.blend_strategy, BlendStrategy::WeightedEnsemble)
534                && !self.ensemble_surrogates.is_empty()
535            {
536                // Blend: weighted average of acquisition values from each surrogate.
537                let target_val = if self.surrogate.n_train() > 0 {
538                    acquisition
539                        .evaluate(&row, &self.surrogate)
540                        .unwrap_or(f64::NEG_INFINITY)
541                } else {
542                    0.0
543                };
544
545                let total_weight: f64 = self
546                    .ensemble_surrogates
547                    .iter()
548                    .map(|(_, w)| *w)
549                    .sum::<f64>()
550                    + 1.0;
551
552                let mut blended = target_val;
553                for (gp, w) in &self.ensemble_surrogates {
554                    let acq_val = acquisition.evaluate(&row, gp).unwrap_or(f64::NEG_INFINITY);
555                    blended += w * acq_val;
556                }
557                blended / total_weight
558            } else {
559                acquisition
560                    .evaluate(&row, &self.surrogate)
561                    .unwrap_or(f64::NEG_INFINITY)
562            };
563
564            if val > best_acq {
565                best_acq = val;
566                best_x = row.to_vec();
567            }
568        }
569
570        Ok(best_x)
571    }
572
573    /// Record an observation of the objective at point `x` with value `y`.
574    pub fn tell(&mut self, x: Vec<f64>, y: f64) -> OptimizeResult<()> {
575        let ndim = self.bounds.len();
576        if x.len() != ndim {
577            return Err(OptimizeError::InvalidInput(format!(
578                "x has {} dims but bounds has {}",
579                x.len(),
580                ndim
581            )));
582        }
583
584        if y < self.f_best {
585            self.f_best = y;
586        }
587        self.best_history.push(self.f_best);
588        self.observations.push(WarmStartObs {
589            x: Array1::from_vec(x.clone()),
590            y,
591        });
592
593        // Refit the surrogate on all target observations (+ any injected prior data).
594        let n = self.observations.len();
595        let mut x_rows = Vec::with_capacity(n * ndim);
596        let mut y_vec = Vec::with_capacity(n);
597        for obs in &self.observations {
598            x_rows.extend(obs.x.iter().copied());
599            y_vec.push(obs.y);
600        }
601        let x_mat = Array2::from_shape_vec((n, ndim), x_rows)
602            .map_err(|e| OptimizeError::ComputationError(format!("shape: {}", e)))?;
603        let y_arr = Array1::from_vec(y_vec);
604        self.surrogate.fit(&x_mat, &y_arr)?;
605
606        Ok(())
607    }
608
609    /// Run the full optimization loop.
610    pub fn optimize<F>(
611        &mut self,
612        mut objective: F,
613        n_calls: usize,
614    ) -> OptimizeResult<WarmStartResult>
615    where
616        F: FnMut(&[f64]) -> f64,
617    {
618        // Inject prior data before starting the target-task evaluations.
619        self.inject_prior_data()?;
620
621        for iter in 0..n_calls {
622            let x = self.ask()?;
623            let y = objective(&x);
624
625            if self.config.verbose >= 2 {
626                println!("[WarmStartBo iter {}] x={:?} y={:.6}", iter, x, y);
627            }
628
629            self.tell(x, y)?;
630        }
631
632        if self.config.verbose >= 1 {
633            println!(
634                "[WarmStartBo] Done. Best f={:.6} after {} evals",
635                self.f_best,
636                self.observations.len()
637            );
638        }
639
640        let (x_best, f_best) = self
641            .observations
642            .iter()
643            .min_by(|a, b| a.y.partial_cmp(&b.y).unwrap_or(std::cmp::Ordering::Equal))
644            .map(|o| (o.x.clone(), o.y))
645            .ok_or_else(|| OptimizeError::ComputationError("No observations".into()))?;
646
647        Ok(WarmStartResult {
648            x_best,
649            f_best,
650            observations: self.observations.clone(),
651            n_evals: self.observations.len(),
652            best_history: self.best_history.clone(),
653        })
654    }
655
656    /// Access the current best known value.
657    pub fn best_value(&self) -> f64 {
658        self.f_best
659    }
660
661    /// Access all target-task observations.
662    pub fn observations(&self) -> &[WarmStartObs] {
663        &self.observations
664    }
665}
666
667// ---------------------------------------------------------------------------
668// Multi-task BO
669// ---------------------------------------------------------------------------
670
671/// A task descriptor for multi-task BO.
672#[derive(Debug, Clone)]
673pub struct Task {
674    /// Human-readable identifier.
675    pub name: String,
676    /// Search space bounds.
677    pub bounds: Vec<(f64, f64)>,
678    /// Prior observations (may be empty for the target task).
679    pub observations_x: Array2<f64>,
680    pub observations_y: Array1<f64>,
681}
682
/// Settings for multi-task Bayesian optimization.
#[derive(Clone)]
pub struct MultiTaskBoConfig {
    /// Position of the target task within the task list.
    pub target_task_idx: usize,
    /// Evaluation budget for the target task.
    pub n_calls: usize,
    /// Random target-task evaluations made before acquisition-driven proposals.
    pub n_initial: usize,
    /// Seed for reproducibility.
    pub seed: Option<u64>,
    /// Number of candidate points scored per acquisition-optimisation step.
    pub acq_n_candidates: usize,
    /// Temperature of the softmax used for task-similarity weights.
    pub similarity_temperature: f64,
}

impl Default for MultiTaskBoConfig {
    fn default() -> Self {
        MultiTaskBoConfig {
            acq_n_candidates: 200,
            n_calls: 20,
            n_initial: 5,
            seed: None,
            similarity_temperature: 1.0,
            target_task_idx: 0,
        }
    }
}
712
713/// Multi-task Bayesian optimizer.
714///
715/// Maintains one GP surrogate per task and combines acquisition values
716/// using task-similarity weights, boosting sample efficiency on the target task.
717pub struct MultiTaskBo {
718    tasks: Vec<Task>,
719    config: MultiTaskBoConfig,
720    /// GP surrogate per task.
721    surrogates: Vec<GpSurrogate>,
722    rng: StdRng,
723    f_best: f64,
724    target_obs: Vec<WarmStartObs>,
725    best_history: Vec<f64>,
726    /// Similarity weights for each source task.
727    task_weights: Vec<f64>,
728}
729
730impl MultiTaskBo {
731    /// Create a new multi-task BO instance.
732    pub fn new(tasks: Vec<Task>, config: MultiTaskBoConfig) -> OptimizeResult<Self> {
733        if tasks.is_empty() {
734            return Err(OptimizeError::InvalidInput(
735                "tasks must not be empty".into(),
736            ));
737        }
738        if config.target_task_idx >= tasks.len() {
739            return Err(OptimizeError::InvalidInput(format!(
740                "target_task_idx {} out of range ({})",
741                config.target_task_idx,
742                tasks.len()
743            )));
744        }
745
746        let seed = config.seed.unwrap_or(0);
747        let rng = StdRng::seed_from_u64(seed);
748
749        // Fit a GP for each task that has observations.
750        let mut surrogates = Vec::with_capacity(tasks.len());
751        for task in &tasks {
752            let gp_config = GpSurrogateConfig {
753                noise_variance: 1e-4,
754                optimize_hyperparams: false,
755                ..Default::default()
756            };
757            let mut gp = GpSurrogate::new(Box::new(RbfKernel::default()), gp_config);
758            if task.observations_x.nrows() >= 2 {
759                let _ = gp.fit(&task.observations_x, &task.observations_y);
760            }
761            surrogates.push(gp);
762        }
763
764        // Compute task-similarity weights relative to the target task.
765        let target_bounds = &tasks[config.target_task_idx].bounds;
766        let temp = config.similarity_temperature.max(1e-10);
767
768        let mut raw_weights = Vec::with_capacity(tasks.len());
769        for (i, task) in tasks.iter().enumerate() {
770            if i == config.target_task_idx {
771                raw_weights.push(1.0_f64); // target task always weight 1.
772            } else {
773                // Compute spatial overlap.
774                let n_in_bounds: usize = (0..task.observations_x.nrows())
775                    .filter(|&j| {
776                        task.observations_x
777                            .row(j)
778                            .iter()
779                            .zip(target_bounds.iter())
780                            .all(|(&v, &(lo, hi))| v >= lo && v <= hi)
781                    })
782                    .count();
783                let frac = n_in_bounds as f64 / task.observations_x.nrows().max(1) as f64;
784                raw_weights.push((frac / temp).exp());
785            }
786        }
787        let weight_sum = raw_weights.iter().sum::<f64>().max(1e-10);
788        let task_weights: Vec<f64> = raw_weights.iter().map(|w| w / weight_sum).collect();
789
790        Ok(Self {
791            tasks,
792            config,
793            surrogates,
794            rng,
795            f_best: f64::INFINITY,
796            target_obs: Vec::new(),
797            best_history: Vec::new(),
798            task_weights,
799        })
800    }
801
802    /// Suggest the next point to evaluate on the target task.
803    pub fn ask(&mut self) -> OptimizeResult<Vec<f64>> {
804        let target_idx = self.config.target_task_idx;
805        let bounds = &self.tasks[target_idx].bounds;
806
807        if self.target_obs.len() < self.config.n_initial {
808            let x: Vec<f64> = bounds
809                .iter()
810                .map(|&(lo, hi)| lo + self.rng.random::<f64>() * (hi - lo))
811                .collect();
812            return Ok(x);
813        }
814
815        let candidates = generate_samples(
816            self.config.acq_n_candidates,
817            bounds,
818            SamplingStrategy::LatinHypercube,
819            None,
820        )?;
821
822        let acq = ExpectedImprovement::new(self.f_best, 0.01);
823
824        let mut best_val = f64::NEG_INFINITY;
825        let mut best_x = candidates.row(0).to_vec();
826
827        for i in 0..candidates.nrows() {
828            let row = candidates.row(i);
829            let mut val = 0.0_f64;
830
831            for (t, (gp, w)) in self
832                .surrogates
833                .iter()
834                .zip(self.task_weights.iter())
835                .enumerate()
836            {
837                if gp.n_train() == 0 {
838                    continue;
839                }
840                // For source tasks, only evaluate if the candidate is in their domain.
841                let in_domain = if t != target_idx {
842                    row.iter()
843                        .zip(self.tasks[t].bounds.iter())
844                        .all(|(&v, &(lo, hi))| v >= lo && v <= hi)
845                } else {
846                    true
847                };
848
849                if in_domain {
850                    let acq_val = acq.evaluate(&row, gp).unwrap_or(0.0);
851                    val += w * acq_val;
852                }
853            }
854
855            if val > best_val {
856                best_val = val;
857                best_x = row.to_vec();
858            }
859        }
860
861        Ok(best_x)
862    }
863
864    /// Record an observation on the target task.
865    pub fn tell(&mut self, x: Vec<f64>, y: f64) -> OptimizeResult<()> {
866        let target_idx = self.config.target_task_idx;
867        let ndim = self.tasks[target_idx].bounds.len();
868
869        if y < self.f_best {
870            self.f_best = y;
871        }
872        self.best_history.push(self.f_best);
873
874        let obs = WarmStartObs {
875            x: Array1::from_vec(x.clone()),
876            y,
877        };
878        self.target_obs.push(obs);
879
880        // Refit target GP.
881        let n = self.target_obs.len();
882        let mut x_rows = Vec::with_capacity(n * ndim);
883        let mut y_vec = Vec::with_capacity(n);
884        for o in &self.target_obs {
885            x_rows.extend(o.x.iter().copied());
886            y_vec.push(o.y);
887        }
888        let x_mat = Array2::from_shape_vec((n, ndim), x_rows)
889            .map_err(|e| OptimizeError::ComputationError(format!("shape: {}", e)))?;
890        let y_arr = Array1::from_vec(y_vec);
891        self.surrogates[target_idx].fit(&x_mat, &y_arr)?;
892
893        Ok(())
894    }
895
896    /// Run the full optimization loop on the target task.
897    pub fn optimize<F>(&mut self, mut objective: F) -> OptimizeResult<WarmStartResult>
898    where
899        F: FnMut(&[f64]) -> f64,
900    {
901        for iter in 0..self.config.n_calls {
902            let x = self.ask()?;
903            let y = objective(&x);
904            let _ = iter;
905            self.tell(x, y)?;
906        }
907
908        let (x_best, f_best) = self
909            .target_obs
910            .iter()
911            .min_by(|a, b| a.y.partial_cmp(&b.y).unwrap_or(std::cmp::Ordering::Equal))
912            .map(|o| (o.x.clone(), o.y))
913            .ok_or_else(|| OptimizeError::ComputationError("No observations".into()))?;
914
915        Ok(WarmStartResult {
916            x_best,
917            f_best,
918            observations: self.target_obs.clone(),
919            n_evals: self.target_obs.len(),
920            best_history: self.best_history.clone(),
921        })
922    }
923}
924
925// ---------------------------------------------------------------------------
926// Convenience function
927// ---------------------------------------------------------------------------
928
929/// Run Bayesian optimization with warm-starting from prior runs.
930///
931/// # Arguments
932///
933/// * `objective` - The objective function to minimize.
934/// * `bounds` - Search space bounds.
935/// * `prior_runs` - Prior optimization runs to warm-start from.
936/// * `n_calls` - Number of evaluations on the target task.
937/// * `seed` - Optional random seed.
938pub fn warm_start_optimize<F>(
939    objective: F,
940    bounds: Vec<(f64, f64)>,
941    prior_runs: Vec<PriorRun>,
942    n_calls: usize,
943    seed: Option<u64>,
944) -> OptimizeResult<WarmStartResult>
945where
946    F: FnMut(&[f64]) -> f64,
947{
948    let config = WarmStartConfig {
949        prior_runs,
950        seed,
951        n_initial: (n_calls / 4).max(3),
952        ..Default::default()
953    };
954    let mut bo = WarmStartBo::new(bounds, config)?;
955    bo.optimize(objective, n_calls)
956}
957
958// ---------------------------------------------------------------------------
959// Tests
960// ---------------------------------------------------------------------------
961
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Prior run that samples `(x - shift)^2` on the integer grid 0..=4, with bounds `[0, 4]`
    /// and weight 1.
    fn make_prior_run(shift: f64) -> PriorRun {
        let grid: Vec<f64> = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let values: Vec<f64> = grid.iter().map(|&v| (v - shift).powi(2)).collect();
        PriorRun {
            x: Array2::from_shape_vec((5, 1), grid).expect("shape"),
            y: Array1::from_vec(values),
            bounds: vec![(0.0, 4.0)],
            weight: 1.0,
        }
    }

    /// Single-prior configuration using the default blend strategy.
    fn single_prior_config(prior: PriorRun, n_initial: usize, seed: u64) -> WarmStartConfig {
        WarmStartConfig {
            prior_runs: vec![prior],
            n_initial,
            seed: Some(seed),
            ..Default::default()
        }
    }

    /// Minimize `(x - target)^2` on `[0, 4]` for `n_calls` evaluations.
    fn run_parabola(config: WarmStartConfig, target: f64, n_calls: usize) -> WarmStartResult {
        let mut bo = WarmStartBo::new(vec![(0.0, 4.0)], config).expect("create");
        bo.optimize(move |x: &[f64]| (x[0] - target).powi(2), n_calls)
            .expect("optimize")
    }

    #[test]
    fn test_warm_start_bo_runs() {
        let config = single_prior_config(make_prior_run(2.0), 3, 42);
        let outcome = run_parabola(config, 2.0, 8);

        assert!(outcome.n_evals > 0, "should have evaluations");
        assert!(outcome.f_best.is_finite(), "best value should be finite");
        assert!(outcome.f_best >= 0.0, "squared distance is non-negative");
    }

    #[test]
    fn test_warm_start_rescale_strategy() {
        let config = WarmStartConfig {
            blend_strategy: BlendStrategy::RescaleAndInject,
            ..single_prior_config(make_prior_run(1.5), 2, 7)
        };
        let outcome = run_parabola(config, 1.5, 6);

        assert!(outcome.f_best.is_finite());
    }

    #[test]
    fn test_warm_start_hyperparam_only_strategy() {
        let config = WarmStartConfig {
            blend_strategy: BlendStrategy::HyperparamOnly,
            ..single_prior_config(make_prior_run(3.0), 3, 99)
        };
        let outcome = run_parabola(config, 3.0, 6);

        assert!(outcome.f_best.is_finite());
    }

    #[test]
    fn test_meta_learner_suggests_finite_hyperparams() {
        let learner = MetaLearner::from_runs(&[make_prior_run(1.0)]);
        let (length_scale, signal_var, noise) = learner.suggest_hyperparams(&[(0.0, 4.0)]);

        for value in [length_scale, signal_var, noise] {
            assert!(value > 0.0 && value.is_finite());
        }
    }

    #[test]
    fn test_task_similarity() {
        let prior = make_prior_run(0.0);

        // Every prior point lies inside [0, 4], so the similarity must be positive.
        let overlapping = MetaLearner::task_similarity(&prior, &[(0.0, 4.0)]);
        assert!(
            overlapping > 0.0 && overlapping <= 1.0,
            "similarity={}",
            overlapping
        );

        // A disjoint target domain [-10, -5] shares no points with the prior.
        let disjoint = MetaLearner::task_similarity(&prior, &[(-10.0, -5.0)]);
        assert_eq!(disjoint, 0.0);
    }

    #[test]
    fn test_multi_task_bo_runs() {
        let source = Task {
            name: "source".into(),
            bounds: vec![(0.0, 4.0)],
            observations_x: Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0])
                .expect("shape"),
            observations_y: Array1::from_vec(vec![4.0, 1.0, 0.0, 1.0]),
        };
        let target = Task {
            name: "target".into(),
            bounds: vec![(0.0, 4.0)],
            observations_x: Array2::zeros((0, 1)),
            observations_y: Array1::zeros(0),
        };
        let config = MultiTaskBoConfig {
            target_task_idx: 1,
            n_calls: 6,
            n_initial: 3,
            seed: Some(42),
            ..Default::default()
        };

        let mut mtbo = MultiTaskBo::new(vec![source, target], config).expect("create");
        let outcome = mtbo
            .optimize(|x: &[f64]| (x[0] - 2.0_f64).powi(2))
            .expect("optimize");

        assert!(outcome.n_evals > 0);
        assert!(outcome.f_best.is_finite());
    }

    #[test]
    fn test_warm_start_optimize_fn() {
        let outcome = warm_start_optimize(
            |x: &[f64]| (x[0] - 0.5_f64).powi(2),
            vec![(0.0, 4.0)],
            vec![make_prior_run(0.5)],
            8,
            Some(42),
        )
        .expect("optimize");

        assert!(outcome.f_best.is_finite());
        assert!(outcome.n_evals > 0);
    }
}