Skip to main content

sphereql_embed/
tuner.rs

1//! Auto-tuner: search [`PipelineConfig`] space to maximize a [`QualityMetric`].
2//!
3//! This is the first usable rung of the metalearning ladder. Given a corpus
4//! and a scalar objective, the tuner enumerates or samples candidate
5//! configurations, builds a full pipeline for each, and records the score.
6//! No gradients, no surrogate models — just a reproducible random / grid
7//! sweep that establishes a baseline for higher-order tuners (Bayesian
8//! optimization, CMA-ES, meta-learning) to beat.
9//!
//! Projections are fit lazily, **once per distinct fit-affecting
//! configuration** — one fit each for PCA and Kernel PCA, and one per
//! `(k_neighbors, active_threshold)` pair for Laplacian eigenmap — and the
//! fitted projection is cached and reused by every later trial that maps
//! to the same key. The remaining downstream knobs (bridge thresholds,
//! inner-sphere gates, domain-group counts, etc.) never trigger a re-fit.
15
16use std::collections::HashMap;
17use std::time::Instant;
18
19use crate::config::{
20    BridgeConfig, InnerSphereConfig, LaplacianConfig, PipelineConfig, ProjectionKind, RoutingConfig,
21};
22use crate::configured_projection::ConfiguredProjection;
23use crate::pipeline::{PipelineError, PipelineInput, SphereQLPipeline, fit_projection_for_config};
24use crate::projection::SplitMix64;
25use crate::quality_metric::QualityMetric;
26use crate::types::Embedding;
27
28// ── Search space ───────────────────────────────────────────────────────
29
30/// Discrete candidate values for each tunable knob.
31///
32/// Every field holds the full set of values the tuner will consider for
33/// that knob. Grid search enumerates the Cartesian product; random search
34/// samples uniformly from each set per trial.
35///
36/// Defaults are chosen to bracket the historical hardcoded value on each
37/// knob, giving the tuner room to move either direction without being
38/// unreasonable.
39#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct SearchSpace {
41    /// Candidate projection families for the outer sphere. Each kind is
42    /// prefit once per distinct fit-affecting hyperparameter tuple in
43    /// [`auto_tune`]; trials pick the prefit matching their config.
44    pub projection_kinds: Vec<ProjectionKind>,
45
46    // ── Projection-kind-specific knobs ────────────────────────────────
47    // These only take effect when the trial's projection_kind matches.
48    // PCA trials ignore them (no waste — grid enumeration is
49    // kind-conditional, so PCA trials don't multiply against these
50    // dimensions).
51    /// Candidate values for [`LaplacianConfig::k_neighbors`]. Only
52    /// explored when [`ProjectionKind::LaplacianEigenmap`] is in
53    /// `projection_kinds`.
54    pub laplacian_k_neighbors: Vec<usize>,
55    /// Candidate values for [`LaplacianConfig::active_threshold`]. Only
56    /// explored when [`ProjectionKind::LaplacianEigenmap`] is in
57    /// `projection_kinds`.
58    pub laplacian_active_threshold: Vec<f64>,
59
60    // ── Kind-agnostic knobs ───────────────────────────────────────────
61    /// Candidate values for [`RoutingConfig::num_domain_groups`].
62    pub num_domain_groups: Vec<usize>,
63    /// Candidate values for [`RoutingConfig::low_evr_threshold`].
64    pub low_evr_threshold: Vec<f64>,
65    /// Candidate values for [`BridgeConfig::overlap_artifact_territorial`].
66    pub overlap_artifact_territorial: Vec<f64>,
67    /// Candidate values for [`BridgeConfig::threshold_base`].
68    pub threshold_base: Vec<f64>,
69    /// Candidate values for [`BridgeConfig::threshold_evr_penalty`].
70    pub threshold_evr_penalty: Vec<f64>,
71    /// Candidate values for [`InnerSphereConfig::min_evr_improvement`].
72    pub min_evr_improvement: Vec<f64>,
73}
74
75impl Default for SearchSpace {
76    fn default() -> Self {
77        Self {
78            // Kernel PCA has O(n²·d) fit and is excluded from the default
79            // sweep — callers who want it can add ProjectionKind::KernelPca
80            // explicitly, accepting the longer fit cost.
81            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
82            // Laplacian hyperparameters bracket the default values
83            // (k=15, threshold=0.05) widely enough that the tuner can
84            // actually move the projection's geometry.
85            laplacian_k_neighbors: vec![10, 15, 25],
86            laplacian_active_threshold: vec![0.03, 0.05, 0.10],
87            num_domain_groups: vec![3, 5, 7],
88            low_evr_threshold: vec![0.25, 0.35, 0.45],
89            overlap_artifact_territorial: vec![0.2, 0.3, 0.4],
90            threshold_base: vec![0.4, 0.5, 0.6],
91            threshold_evr_penalty: vec![0.2, 0.4, 0.6],
92            min_evr_improvement: vec![0.05, 0.10, 0.15],
93        }
94    }
95}
96
97impl SearchSpace {
98    /// Number of kind-agnostic knob combinations. Every projection kind's
99    /// grid slice is at least this large; Laplacian multiplies by its
100    /// specific knob counts on top.
101    fn common_cardinality(&self) -> usize {
102        self.num_domain_groups.len()
103            * self.low_evr_threshold.len()
104            * self.overlap_artifact_territorial.len()
105            * self.threshold_base.len()
106            * self.threshold_evr_penalty.len()
107            * self.min_evr_improvement.len()
108    }
109
110    /// Per-kind grid cardinality — common knobs × any kind-specific
111    /// knobs this kind opts into.
112    fn kind_cardinality(&self, kind: ProjectionKind) -> usize {
113        let common = self.common_cardinality();
114        match kind {
115            ProjectionKind::LaplacianEigenmap => {
116                common * self.laplacian_k_neighbors.len() * self.laplacian_active_threshold.len()
117            }
118            ProjectionKind::Pca | ProjectionKind::KernelPca => common,
119        }
120    }
121
122    /// Cardinality of the kind-conditional grid: the sum of each projection
123    /// kind's own slice. `grid` search visits exactly this many configurations.
124    pub fn grid_cardinality(&self) -> usize {
125        self.projection_kinds
126            .iter()
127            .map(|&k| self.kind_cardinality(k))
128            .sum()
129    }
130
131    /// Build a [`PipelineConfig`] from one grid index.
132    ///
133    /// The grid is laid out as disjoint per-kind slices concatenated in
134    /// the order of [`Self::projection_kinds`]: indices 0..c₀ enumerate
135    /// the first kind's subspace, c₀..c₀+c₁ the second kind's, etc. This
136    /// keeps kind-specific knobs (e.g. Laplacian's k, threshold) from
137    /// multiplying against trials of other kinds that wouldn't use them.
138    pub fn config_at_index(&self, index: usize, base: &PipelineConfig) -> Option<PipelineConfig> {
139        let mut offset = 0usize;
140        for &kind in &self.projection_kinds {
141            let slice = self.kind_cardinality(kind);
142            if index < offset + slice {
143                return Some(self.config_at_kind_index(kind, index - offset, base));
144            }
145            offset += slice;
146        }
147        None
148    }
149
150    /// Decode an index within a single kind's slice.
151    fn config_at_kind_index(
152        &self,
153        kind: ProjectionKind,
154        mut idx: usize,
155        base: &PipelineConfig,
156    ) -> PipelineConfig {
157        let take = |idx: &mut usize, len: usize| -> usize {
158            let v = *idx % len;
159            *idx /= len;
160            v
161        };
162
163        let i_ndg = take(&mut idx, self.num_domain_groups.len());
164        let i_let = take(&mut idx, self.low_evr_threshold.len());
165        let i_oat = take(&mut idx, self.overlap_artifact_territorial.len());
166        let i_tb = take(&mut idx, self.threshold_base.len());
167        let i_tep = take(&mut idx, self.threshold_evr_penalty.len());
168        let i_mei = take(&mut idx, self.min_evr_improvement.len());
169
170        let mut cfg = base.clone();
171        cfg.projection_kind = kind;
172        cfg.routing = RoutingConfig {
173            num_domain_groups: self.num_domain_groups[i_ndg],
174            low_evr_threshold: self.low_evr_threshold[i_let],
175        };
176        cfg.bridges = BridgeConfig {
177            threshold_base: self.threshold_base[i_tb],
178            threshold_evr_penalty: self.threshold_evr_penalty[i_tep],
179            overlap_artifact_territorial: self.overlap_artifact_territorial[i_oat],
180        };
181        cfg.inner_sphere = InnerSphereConfig {
182            min_evr_improvement: self.min_evr_improvement[i_mei],
183            ..base.inner_sphere.clone()
184        };
185
186        if matches!(kind, ProjectionKind::LaplacianEigenmap) {
187            let i_k = take(&mut idx, self.laplacian_k_neighbors.len());
188            let i_thr = take(&mut idx, self.laplacian_active_threshold.len());
189            cfg.laplacian = LaplacianConfig {
190                k_neighbors: self.laplacian_k_neighbors[i_k],
191                active_threshold: self.laplacian_active_threshold[i_thr],
192            };
193        }
194
195        cfg
196    }
197
198    /// Sample one random [`PipelineConfig`] from this space. Every knob's
199    /// value set is sampled uniformly and independently; kind-specific
200    /// knobs are only sampled when the sampled kind uses them. Internal
201    /// to the tuner — external callers go through [`auto_tune`] with a
202    /// [`SearchStrategy::Random`] strategy.
203    pub(crate) fn sample(&self, rng: &mut SplitMix64, base: &PipelineConfig) -> PipelineConfig {
204        let mut cfg = base.clone();
205        cfg.projection_kind = pick_uniform(rng, &self.projection_kinds);
206        cfg.routing = RoutingConfig {
207            num_domain_groups: pick_uniform(rng, &self.num_domain_groups),
208            low_evr_threshold: pick_uniform(rng, &self.low_evr_threshold),
209        };
210        cfg.bridges = BridgeConfig {
211            threshold_base: pick_uniform(rng, &self.threshold_base),
212            threshold_evr_penalty: pick_uniform(rng, &self.threshold_evr_penalty),
213            overlap_artifact_territorial: pick_uniform(rng, &self.overlap_artifact_territorial),
214        };
215        cfg.inner_sphere = InnerSphereConfig {
216            min_evr_improvement: pick_uniform(rng, &self.min_evr_improvement),
217            ..base.inner_sphere.clone()
218        };
219
220        if matches!(cfg.projection_kind, ProjectionKind::LaplacianEigenmap) {
221            cfg.laplacian = LaplacianConfig {
222                k_neighbors: pick_uniform(rng, &self.laplacian_k_neighbors),
223                active_threshold: pick_uniform(rng, &self.laplacian_active_threshold),
224            };
225        }
226
227        cfg
228    }
229}
230
231// ── Prefit cache key ──────────────────────────────────────────────────
232
/// Cache key naming one distinct projection fit.
///
/// Any two [`PipelineConfig`]s that map to equal keys can share a single
/// fitted projection; configs with different keys must be fitted
/// separately. Within the current search space PCA and Kernel PCA have no
/// hyperparameters that influence the fit, so each gets a single key,
/// whereas a Laplacian fit is determined by the pair
/// `(k_neighbors, active_threshold)`.
#[derive(Clone, PartialEq, Eq, Hash)]
enum ProjectionFitKey {
    /// The one PCA fit.
    Pca,
    /// The one Kernel PCA fit.
    KernelPca,
    /// A Laplacian-eigenmap fit. The threshold is kept as its raw bit
    /// pattern so the key can be `Eq` + `Hash`, which `f64` is not.
    Laplacian { k: usize, threshold_bits: u64 },
}
246
247impl ProjectionFitKey {
248    fn from_config(cfg: &PipelineConfig) -> Self {
249        match cfg.projection_kind {
250            ProjectionKind::Pca => Self::Pca,
251            ProjectionKind::KernelPca => Self::KernelPca,
252            ProjectionKind::LaplacianEigenmap => Self::Laplacian {
253                k: cfg.laplacian.k_neighbors,
254                threshold_bits: cfg.laplacian.active_threshold.to_bits(),
255            },
256        }
257    }
258}
259
260// ── Strategy, report, trial record ─────────────────────────────────────
261
/// How [`auto_tune`] walks the [`SearchSpace`].
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// Visit every point of the kind-conditional grid exactly once. The
    /// number of trials equals [`SearchSpace::grid_cardinality`].
    Grid,
    /// Run `budget` trials, each drawn uniformly at random from the
    /// space by a generator seeded with `seed`.
    Random { budget: usize, seed: u64 },
    /// Sequential, Bayesian-flavoured search ("TPE-lite").
    ///
    /// The first `warmup` trials are uniform random draws. Every later
    /// trial chooses each knob's value according to how much more often
    /// that value shows up among the top `gamma` fraction of past trials
    /// ("good") than among the remaining `1 − gamma` ("bad"). The
    /// acquisition treats knobs independently (axis-parallel), applies
    /// Laplace smoothing, and is reproducible for a fixed `seed`.
    ///
    /// The extra code buys sample efficiency over plain random search;
    /// the original notes cite roughly 30% fewer trials to reach the
    /// random-search ceiling on the default space (figure not verified
    /// here).
    Bayesian {
        /// Total number of trials, warmup included.
        budget: usize,
        /// Uniform random trials run before the acquisition takes over.
        /// At least 2 are needed for a non-degenerate good/bad split;
        /// [`auto_tune`] raises smaller values to 2.
        warmup: usize,
        /// Share of past trials counted as "good" when fitting the
        /// acquisition. 0.25 is the TPE default; lower values exploit
        /// more, higher values explore more. [`auto_tune`] clamps it to
        /// `[0.05, 0.95]`.
        gamma: f64,
        /// Seed for the strategy's random generator.
        seed: u64,
    },
}
292
293/// One trial's observation.
294#[derive(Debug, Clone)]
295pub struct TrialRecord {
296    pub config: PipelineConfig,
297    pub score: f64,
298    /// Wall-clock build time for this trial (pipeline rebuild only —
299    /// projection fit is amortized across the tuner run).
300    pub build_ms: u128,
301}
302
303/// Full tuner output.
304#[derive(Debug, Clone)]
305pub struct TuneReport {
306    pub metric_name: String,
307    pub best_score: f64,
308    pub best_config: PipelineConfig,
309    pub trials: Vec<TrialRecord>,
310    /// Trials that failed to build (e.g., too few embeddings, config
311    /// combination rejected by a downstream validator). Each entry is
312    /// `(config, error_message)`.
313    pub failures: Vec<(PipelineConfig, String)>,
314}
315
316impl TuneReport {
317    /// Trials ranked by descending score.
318    pub fn ranked_trials(&self) -> Vec<&TrialRecord> {
319        let mut refs: Vec<&TrialRecord> = self.trials.iter().collect();
320        refs.sort_by(|a, b| {
321            b.score
322                .partial_cmp(&a.score)
323                .unwrap_or(std::cmp::Ordering::Equal)
324        });
325        refs
326    }
327
328    /// Mean score across successful trials. Useful for gauging how
329    /// sensitive the pipeline is to the tuned knobs: a flat landscape
330    /// means the knobs don't matter on this corpus.
331    pub fn mean_score(&self) -> f64 {
332        if self.trials.is_empty() {
333            return 0.0;
334        }
335        self.trials.iter().map(|t| t.score).sum::<f64>() / self.trials.len() as f64
336    }
337}
338
339// ── The tuner itself ───────────────────────────────────────────────────
340
341/// Run the auto-tuner and return the best pipeline plus a report.
342///
343/// Fits one projection per [`ProjectionKind`] listed in
344/// `space.projection_kinds` (honoring Laplacian hyperparameters from
345/// `base_config.laplacian`), then reuses those prefit projections across
346/// every trial. Only the downstream [`PipelineConfig`] knobs (bridge
347/// thresholds, inner-sphere gates, domain-group counts, etc.) vary per
348/// trial — this keeps per-trial cost dominated by spatial quality
349/// sampling and graph construction rather than projection fitting.
350pub fn auto_tune<M: QualityMetric + ?Sized>(
351    input: PipelineInput,
352    space: &SearchSpace,
353    metric: &M,
354    strategy: SearchStrategy,
355    base_config: &PipelineConfig,
356) -> Result<(SphereQLPipeline, TuneReport), PipelineError> {
357    // `PipelineInput` is owned — move the Vec<f64>s straight into the
358    // Embedding wrappers instead of cloning each row.
359    let categories = input.categories;
360    let embeddings: Vec<Embedding> = input.embeddings.into_iter().map(Embedding::new).collect();
361
362    let mut prefit: HashMap<ProjectionFitKey, ConfiguredProjection> = HashMap::new();
363    let mut trials: Vec<TrialRecord> = Vec::new();
364    let mut failures: Vec<(PipelineConfig, String)> = Vec::new();
365
366    // Closure: evaluate one config, update prefit cache, push record or
367    // failure. Shared by every strategy so they only differ in how they
368    // propose configs.
369    let run_trial = |cfg: PipelineConfig,
370                     prefit: &mut HashMap<ProjectionFitKey, ConfiguredProjection>,
371                     trials: &mut Vec<TrialRecord>,
372                     failures: &mut Vec<(PipelineConfig, String)>| {
373        let key = ProjectionFitKey::from_config(&cfg);
374        let projection = match prefit.get(&key) {
375            Some(p) => p.clone(),
376            None => match fit_projection_for_config(&embeddings, &cfg) {
377                Ok(p) => {
378                    prefit.insert(key, p.clone());
379                    p
380                }
381                Err(e) => {
382                    failures.push((cfg, e.to_string()));
383                    return;
384                }
385            },
386        };
387
388        let start = Instant::now();
389        match SphereQLPipeline::with_configured_projection_and_config(
390            categories.clone(),
391            embeddings.clone(),
392            projection,
393            cfg.clone(),
394        ) {
395            Ok(pipeline) => {
396                let score = metric.score(&pipeline);
397                let build_ms = start.elapsed().as_millis();
398                trials.push(TrialRecord {
399                    config: cfg,
400                    score,
401                    build_ms,
402                });
403            }
404            Err(e) => {
405                failures.push((cfg, e.to_string()));
406            }
407        }
408    };
409
410    match &strategy {
411        SearchStrategy::Grid => {
412            for i in 0..space.grid_cardinality() {
413                if let Some(cfg) = space.config_at_index(i, base_config) {
414                    run_trial(cfg, &mut prefit, &mut trials, &mut failures);
415                }
416            }
417        }
418        SearchStrategy::Random { budget, seed } => {
419            let mut rng = SplitMix64::new(*seed);
420            for _ in 0..*budget {
421                let cfg = space.sample(&mut rng, base_config);
422                run_trial(cfg, &mut prefit, &mut trials, &mut failures);
423            }
424        }
425        SearchStrategy::Bayesian {
426            budget,
427            warmup,
428            gamma,
429            seed,
430        } => {
431            let mut rng = SplitMix64::new(*seed);
432            let budget = *budget;
433            let warmup = (*warmup).clamp(2, budget);
434            let gamma = gamma.clamp(0.05, 0.95);
435
436            // Warmup: uniform random.
437            for _ in 0..warmup {
438                let cfg = space.sample(&mut rng, base_config);
439                run_trial(cfg, &mut prefit, &mut trials, &mut failures);
440            }
441            // Acquisition: axis-parallel TPE-lite.
442            for _ in warmup..budget {
443                let cfg = tpe_propose(space, base_config, &trials, gamma, &mut rng);
444                run_trial(cfg, &mut prefit, &mut trials, &mut failures);
445            }
446        }
447    }
448
449    if trials.is_empty() {
450        // Every candidate config was rejected downstream. Surface the
451        // real failure list instead of the misleading `TooFewEmbeddings`
452        // roll-up we used to return here.
453        return Err(PipelineError::AllTrialsFailed { failures });
454    }
455
456    // Pick the winning trial.
457    let best_idx = trials
458        .iter()
459        .enumerate()
460        .max_by(|(_, a), (_, b)| {
461            a.score
462                .partial_cmp(&b.score)
463                .unwrap_or(std::cmp::Ordering::Equal)
464        })
465        .map(|(i, _)| i)
466        .expect("trials non-empty");
467    let best_config = trials[best_idx].config.clone();
468    let best_score = trials[best_idx].score;
469
470    // Build the winning pipeline fresh so the caller gets it owned.
471    // Winner came from a successful trial, so the prefit cache has its
472    // projection. The unwrap_or is defensive — if the cache entry went
473    // missing somehow, re-fit and propagate any error as a
474    // `PipelineError::Projection`.
475    let best_key = ProjectionFitKey::from_config(&best_config);
476    let best_projection = match prefit.get(&best_key).cloned() {
477        Some(p) => p,
478        None => fit_projection_for_config(&embeddings, &best_config)?,
479    };
480    let best_pipeline = SphereQLPipeline::with_configured_projection_and_config(
481        categories,
482        embeddings,
483        best_projection,
484        best_config.clone(),
485    )?;
486
487    let report = TuneReport {
488        metric_name: metric.name().to_string(),
489        best_score,
490        best_config,
491        trials,
492        failures,
493    };
494
495    Ok((best_pipeline, report))
496}
497
498// ── TPE-lite acquisition ──────────────────────────────────────────────
499
500/// Propose the next [`PipelineConfig`] using axis-parallel good/bad
501/// ratios over the trial history.
502///
503/// For each knob, counts how often each candidate value appeared in the
504/// top-`gamma` fraction ("good") of past trials vs. the rest ("bad").
505/// Samples the next value with probability proportional to
506/// `(good + 1) / (bad + 1)` per candidate, Laplace-smoothed so no value
507/// is ever assigned zero probability.
508///
509/// Kind-specific knobs (Laplacian's `k`, `active_threshold`) condition on
510/// kind — their histograms are built from kind-matching trials only, with
511/// a uniform fallback when fewer than 2 kind-matching trials exist.
512fn tpe_propose(
513    space: &SearchSpace,
514    base: &PipelineConfig,
515    trials: &[TrialRecord],
516    gamma: f64,
517    rng: &mut SplitMix64,
518) -> PipelineConfig {
519    // Sort by descending score, split at gamma threshold.
520    let mut sorted: Vec<&TrialRecord> = trials.iter().collect();
521    sorted.sort_by(|a, b| {
522        b.score
523            .partial_cmp(&a.score)
524            .unwrap_or(std::cmp::Ordering::Equal)
525    });
526    let n_good = ((sorted.len() as f64) * gamma).ceil() as usize;
527    let n_good = n_good.max(1).min(sorted.len().saturating_sub(1).max(1));
528    let good: Vec<&TrialRecord> = sorted.iter().take(n_good).copied().collect();
529    let bad: Vec<&TrialRecord> = sorted.iter().skip(n_good).copied().collect();
530
531    // Fall back to uniform sampling if we somehow don't have both sides.
532    if good.is_empty() || bad.is_empty() {
533        return space.sample(rng, base);
534    }
535
536    let pick_idx = |rng: &mut SplitMix64, good_counts: &[f64], bad_counts: &[f64]| -> usize {
537        let n_g = good_counts.iter().sum::<f64>() + good_counts.len() as f64;
538        let n_b = bad_counts.iter().sum::<f64>() + bad_counts.len() as f64;
539        let weights: Vec<f64> = good_counts
540            .iter()
541            .zip(bad_counts.iter())
542            .map(|(&g, &b)| ((g + 1.0) / n_g) / ((b + 1.0) / n_b))
543            .collect();
544        sample_categorical(rng, &weights)
545    };
546
547    // Projection kind (histogram across all trials).
548    let pk_g = hist_kind(&good, &space.projection_kinds);
549    let pk_b = hist_kind(&bad, &space.projection_kinds);
550    let kind = space.projection_kinds[pick_idx(rng, &pk_g, &pk_b)];
551
552    // Kind-agnostic knobs.
553    let ndg_g = hist_usize(&good, &space.num_domain_groups, |c| {
554        c.routing.num_domain_groups
555    });
556    let ndg_b = hist_usize(&bad, &space.num_domain_groups, |c| {
557        c.routing.num_domain_groups
558    });
559    let let_g = hist_f64(&good, &space.low_evr_threshold, |c| {
560        c.routing.low_evr_threshold
561    });
562    let let_b = hist_f64(&bad, &space.low_evr_threshold, |c| {
563        c.routing.low_evr_threshold
564    });
565    let oat_g = hist_f64(&good, &space.overlap_artifact_territorial, |c| {
566        c.bridges.overlap_artifact_territorial
567    });
568    let oat_b = hist_f64(&bad, &space.overlap_artifact_territorial, |c| {
569        c.bridges.overlap_artifact_territorial
570    });
571    let tb_g = hist_f64(&good, &space.threshold_base, |c| c.bridges.threshold_base);
572    let tb_b = hist_f64(&bad, &space.threshold_base, |c| c.bridges.threshold_base);
573    let tep_g = hist_f64(&good, &space.threshold_evr_penalty, |c| {
574        c.bridges.threshold_evr_penalty
575    });
576    let tep_b = hist_f64(&bad, &space.threshold_evr_penalty, |c| {
577        c.bridges.threshold_evr_penalty
578    });
579    let mei_g = hist_f64(&good, &space.min_evr_improvement, |c| {
580        c.inner_sphere.min_evr_improvement
581    });
582    let mei_b = hist_f64(&bad, &space.min_evr_improvement, |c| {
583        c.inner_sphere.min_evr_improvement
584    });
585
586    let mut cfg = base.clone();
587    cfg.projection_kind = kind;
588    cfg.routing = RoutingConfig {
589        num_domain_groups: space.num_domain_groups[pick_idx(rng, &ndg_g, &ndg_b)],
590        low_evr_threshold: space.low_evr_threshold[pick_idx(rng, &let_g, &let_b)],
591    };
592    cfg.bridges = BridgeConfig {
593        threshold_base: space.threshold_base[pick_idx(rng, &tb_g, &tb_b)],
594        threshold_evr_penalty: space.threshold_evr_penalty[pick_idx(rng, &tep_g, &tep_b)],
595        overlap_artifact_territorial: space.overlap_artifact_territorial
596            [pick_idx(rng, &oat_g, &oat_b)],
597    };
598    cfg.inner_sphere = InnerSphereConfig {
599        min_evr_improvement: space.min_evr_improvement[pick_idx(rng, &mei_g, &mei_b)],
600        ..base.inner_sphere.clone()
601    };
602
603    // Kind-specific knobs: condition on kind-matching trials only.
604    if matches!(kind, ProjectionKind::LaplacianEigenmap) {
605        let good_l: Vec<&TrialRecord> = good
606            .iter()
607            .copied()
608            .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
609            .collect();
610        let bad_l: Vec<&TrialRecord> = bad
611            .iter()
612            .copied()
613            .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
614            .collect();
615        if good_l.is_empty() || bad_l.is_empty() {
616            // Not enough Laplacian trials on both sides — uniform fallback.
617            cfg.laplacian = LaplacianConfig {
618                k_neighbors: space.laplacian_k_neighbors
619                    [(rng.next_u64() as usize) % space.laplacian_k_neighbors.len()],
620                active_threshold: space.laplacian_active_threshold
621                    [(rng.next_u64() as usize) % space.laplacian_active_threshold.len()],
622            };
623        } else {
624            let k_g = hist_usize(&good_l, &space.laplacian_k_neighbors, |c| {
625                c.laplacian.k_neighbors
626            });
627            let k_b = hist_usize(&bad_l, &space.laplacian_k_neighbors, |c| {
628                c.laplacian.k_neighbors
629            });
630            let at_g = hist_f64(&good_l, &space.laplacian_active_threshold, |c| {
631                c.laplacian.active_threshold
632            });
633            let at_b = hist_f64(&bad_l, &space.laplacian_active_threshold, |c| {
634                c.laplacian.active_threshold
635            });
636            cfg.laplacian = LaplacianConfig {
637                k_neighbors: space.laplacian_k_neighbors[pick_idx(rng, &k_g, &k_b)],
638                active_threshold: space.laplacian_active_threshold[pick_idx(rng, &at_g, &at_b)],
639            };
640        }
641    }
642
643    cfg
644}
645
646fn hist_kind(trials: &[&TrialRecord], values: &[ProjectionKind]) -> Vec<f64> {
647    let mut counts = vec![0.0f64; values.len()];
648    for t in trials {
649        if let Some(i) = values.iter().position(|&v| v == t.config.projection_kind) {
650            counts[i] += 1.0;
651        }
652    }
653    counts
654}
655
656fn hist_usize(
657    trials: &[&TrialRecord],
658    values: &[usize],
659    extract: impl Fn(&PipelineConfig) -> usize,
660) -> Vec<f64> {
661    let mut counts = vec![0.0f64; values.len()];
662    for t in trials {
663        let v = extract(&t.config);
664        if let Some(i) = values.iter().position(|&x| x == v) {
665            counts[i] += 1.0;
666        }
667    }
668    counts
669}
670
671/// f64 candidates are matched by nearest-neighbor since equality on
672/// floats is fraught even when every sampled value came from the same
673/// source slice. In practice the match is always exact but this keeps
674/// us honest under future refactors.
675fn hist_f64(
676    trials: &[&TrialRecord],
677    values: &[f64],
678    extract: impl Fn(&PipelineConfig) -> f64,
679) -> Vec<f64> {
680    let mut counts = vec![0.0f64; values.len()];
681    for t in trials {
682        let v = extract(&t.config);
683        if let Some((i, _)) = values.iter().enumerate().min_by(|a, b| {
684            (a.1 - v)
685                .abs()
686                .partial_cmp(&(b.1 - v).abs())
687                .unwrap_or(std::cmp::Ordering::Equal)
688        }) {
689            counts[i] += 1.0;
690        }
691    }
692    counts
693}
694
695/// Pick one element of `vals` uniformly at random. Panics if `vals` is
696/// empty — callers always pass non-empty `SearchSpace` axes, so the
697/// empty case would be a programmer error rather than a recoverable
698/// input.
699fn pick_uniform<T: Copy>(rng: &mut SplitMix64, vals: &[T]) -> T {
700    vals[(rng.next_u64() as usize) % vals.len()]
701}
702
703fn sample_categorical(rng: &mut SplitMix64, weights: &[f64]) -> usize {
704    let total: f64 = weights.iter().sum();
705    if total <= 0.0 || !total.is_finite() {
706        return (rng.next_u64() as usize) % weights.len().max(1);
707    }
708    let r = rng.next_f64() * total;
709    let mut acc = 0.0;
710    for (i, &w) in weights.iter().enumerate() {
711        acc += w;
712        if r <= acc {
713            return i;
714        }
715    }
716    weights.len() - 1
717}
718
719// ── Tests ──────────────────────────────────────────────────────────────
720
721#[cfg(test)]
722mod tests {
723    use super::*;
724    use crate::quality_metric::{BridgeCoherence, CompositeMetric, TerritorialHealth};
725
726    fn make_input(n: usize, dim: usize) -> PipelineInput {
727        let mut embeddings = Vec::new();
728        let mut categories = Vec::new();
729        for i in 0..n {
730            let mut v = vec![0.0; dim];
731            if i < n / 3 {
732                v[0] = 1.0 + (i as f64 * 0.01);
733                v[1] = 0.1;
734                categories.push("one".into());
735            } else if i < 2 * n / 3 {
736                v[2] = 1.0 + (i as f64 * 0.01);
737                v[3] = 0.1;
738                categories.push("two".into());
739            } else {
740                v[4] = 1.0 + (i as f64 * 0.01);
741                v[5] = 0.1;
742                categories.push("three".into());
743            }
744            v[6] = 0.02 * i as f64;
745            embeddings.push(v);
746        }
747        PipelineInput {
748            categories,
749            embeddings,
750        }
751    }
752
753    #[test]
754    fn search_space_grid_cardinality_sums_per_kind() {
755        let s = SearchSpace::default();
756        let common = s.num_domain_groups.len()
757            * s.low_evr_threshold.len()
758            * s.overlap_artifact_territorial.len()
759            * s.threshold_base.len()
760            * s.threshold_evr_penalty.len()
761            * s.min_evr_improvement.len();
762        // Default kinds = {PCA, Laplacian}; PCA adds `common`, Laplacian
763        // adds `common × k_neighbors × active_threshold`.
764        let expected =
765            common + common * s.laplacian_k_neighbors.len() * s.laplacian_active_threshold.len();
766        assert_eq!(s.grid_cardinality(), expected);
767    }
768
769    #[test]
770    fn default_search_space_includes_pca_and_laplacian() {
771        let s = SearchSpace::default();
772        assert!(s.projection_kinds.contains(&ProjectionKind::Pca));
773        assert!(
774            s.projection_kinds
775                .contains(&ProjectionKind::LaplacianEigenmap)
776        );
777        // Kernel PCA excluded by default (expensive fit).
778        assert!(!s.projection_kinds.contains(&ProjectionKind::KernelPca));
779    }
780
781    #[test]
782    fn grid_index_enumerates_full_space() {
783        let s = SearchSpace {
784            projection_kinds: vec![ProjectionKind::Pca],
785            laplacian_k_neighbors: vec![15],
786            laplacian_active_threshold: vec![0.05],
787            num_domain_groups: vec![3, 5],
788            low_evr_threshold: vec![0.3, 0.4],
789            overlap_artifact_territorial: vec![0.3],
790            threshold_base: vec![0.5],
791            threshold_evr_penalty: vec![0.4],
792            min_evr_improvement: vec![0.10],
793        };
794        let base = PipelineConfig::default();
795        let n = s.grid_cardinality();
796        let mut seen = std::collections::HashSet::new();
797        for i in 0..n {
798            let cfg = s.config_at_index(i, &base).unwrap();
799            let key = (
800                cfg.routing.num_domain_groups,
801                (cfg.routing.low_evr_threshold * 1000.0) as i64,
802            );
803            seen.insert(key);
804        }
805        assert_eq!(seen.len(), n);
806        assert!(s.config_at_index(n, &base).is_none());
807    }
808
809    #[test]
810    fn grid_index_enumerates_across_projection_kinds() {
811        let s = SearchSpace {
812            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
813            laplacian_k_neighbors: vec![15],
814            laplacian_active_threshold: vec![0.05],
815            num_domain_groups: vec![3],
816            low_evr_threshold: vec![0.35],
817            overlap_artifact_territorial: vec![0.3],
818            threshold_base: vec![0.5],
819            threshold_evr_penalty: vec![0.4],
820            min_evr_improvement: vec![0.10],
821        };
822        let base = PipelineConfig::default();
823        let kinds: std::collections::HashSet<ProjectionKind> = (0..s.grid_cardinality())
824            .map(|i| s.config_at_index(i, &base).unwrap().projection_kind)
825            .collect();
826        assert_eq!(kinds.len(), 2);
827        assert!(kinds.contains(&ProjectionKind::Pca));
828        assert!(kinds.contains(&ProjectionKind::LaplacianEigenmap));
829    }
830
831    #[test]
832    fn grid_search_runs_and_picks_best() {
833        let input = make_input(24, 8);
834        let space = SearchSpace {
835            projection_kinds: vec![ProjectionKind::Pca],
836            laplacian_k_neighbors: vec![15],
837            laplacian_active_threshold: vec![0.05],
838            num_domain_groups: vec![3, 5],
839            low_evr_threshold: vec![0.35],
840            overlap_artifact_territorial: vec![0.3],
841            threshold_base: vec![0.5],
842            threshold_evr_penalty: vec![0.4],
843            min_evr_improvement: vec![0.10],
844        };
845        let metric = TerritorialHealth;
846        let (pipeline, report) = auto_tune(
847            input,
848            &space,
849            &metric,
850            SearchStrategy::Grid,
851            &PipelineConfig::default(),
852        )
853        .unwrap();
854
855        assert_eq!(report.trials.len(), 2);
856        assert!(report.best_score >= report.mean_score() - 1e-9);
857        assert!(pipeline.num_categories() > 0);
858        assert_eq!(report.metric_name, "territorial_health");
859        assert!(report.failures.is_empty());
860    }
861
862    #[test]
863    fn random_search_respects_budget() {
864        let input = make_input(24, 8);
865        let space = SearchSpace::default();
866        let metric = BridgeCoherence;
867        let (_pipeline, report) = auto_tune(
868            input,
869            &space,
870            &metric,
871            SearchStrategy::Random {
872                budget: 5,
873                seed: 42,
874            },
875            &PipelineConfig::default(),
876        )
877        .unwrap();
878        assert_eq!(report.trials.len(), 5);
879    }
880
881    #[test]
882    fn random_search_is_seed_reproducible() {
883        let space = SearchSpace::default();
884        let metric = TerritorialHealth;
885
886        let run = |seed: u64| {
887            let input = make_input(24, 8);
888            auto_tune(
889                input,
890                &space,
891                &metric,
892                SearchStrategy::Random { budget: 8, seed },
893                &PipelineConfig::default(),
894            )
895            .unwrap()
896            .1
897        };
898
899        let a = run(7);
900        let b = run(7);
901        let c = run(13);
902
903        assert_eq!(a.trials.len(), b.trials.len());
904        for (ta, tb) in a.trials.iter().zip(b.trials.iter()) {
905            assert_eq!(
906                ta.config.routing.num_domain_groups,
907                tb.config.routing.num_domain_groups
908            );
909            assert!((ta.score - tb.score).abs() < 1e-12);
910        }
911        // Different seed should (very likely) produce a different trial
912        // sequence. If it accidentally matches, the test is still valid
913        // but we check at least one config differs.
914        let any_differ = a.trials.iter().zip(c.trials.iter()).any(|(ta, tc)| {
915            ta.config.routing.num_domain_groups != tc.config.routing.num_domain_groups
916                || (ta.config.bridges.threshold_base - tc.config.bridges.threshold_base).abs()
917                    > 1e-12
918        });
919        assert!(any_differ, "different seeds produced identical trial set");
920    }
921
922    #[test]
923    fn ranked_trials_are_descending() {
924        let input = make_input(24, 8);
925        let metric = CompositeMetric::default_composite();
926        let (_p, report) = auto_tune(
927            input,
928            &SearchSpace::default(),
929            &metric,
930            SearchStrategy::Random {
931                budget: 6,
932                seed: 99,
933            },
934            &PipelineConfig::default(),
935        )
936        .unwrap();
937        let ranked = report.ranked_trials();
938        for w in ranked.windows(2) {
939            assert!(w[0].score >= w[1].score);
940        }
941    }
942
943    #[test]
944    fn best_config_actually_in_trials() {
945        let input = make_input(24, 8);
946        let metric = TerritorialHealth;
947        let (_p, report) = auto_tune(
948            input,
949            &SearchSpace::default(),
950            &metric,
951            SearchStrategy::Random { budget: 4, seed: 1 },
952            &PipelineConfig::default(),
953        )
954        .unwrap();
955        let any_match = report.trials.iter().any(|t| {
956            t.config.routing.num_domain_groups == report.best_config.routing.num_domain_groups
957                && (t.config.routing.low_evr_threshold
958                    - report.best_config.routing.low_evr_threshold)
959                    .abs()
960                    < 1e-12
961                && (t.score - report.best_score).abs() < 1e-12
962        });
963        assert!(any_match, "best_config must appear in trials");
964    }
965
966    #[test]
967    fn grid_search_across_projection_kinds_yields_both() {
968        let input = make_input(24, 8);
969        let space = SearchSpace {
970            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
971            laplacian_k_neighbors: vec![10, 20],
972            laplacian_active_threshold: vec![0.05],
973            num_domain_groups: vec![3],
974            low_evr_threshold: vec![0.35],
975            overlap_artifact_territorial: vec![0.3],
976            threshold_base: vec![0.5],
977            threshold_evr_penalty: vec![0.4],
978            min_evr_improvement: vec![0.10],
979        };
980        let metric = TerritorialHealth;
981        let (_pipeline, report) = auto_tune(
982            input,
983            &space,
984            &metric,
985            SearchStrategy::Grid,
986            &PipelineConfig::default(),
987        )
988        .unwrap();
989        // PCA contributes 1 trial; Laplacian contributes 2 × 1 = 2 trials
990        // (two k_neighbors values × one threshold value). Total = 3.
991        assert_eq!(report.trials.len(), 3);
992        let kinds_in_trials: std::collections::HashSet<ProjectionKind> = report
993            .trials
994            .iter()
995            .map(|t| t.config.projection_kind)
996            .collect();
997        assert!(kinds_in_trials.contains(&ProjectionKind::Pca));
998        assert!(kinds_in_trials.contains(&ProjectionKind::LaplacianEigenmap));
999        // Verify the two Laplacian trials actually use different k values.
1000        let lap_ks: std::collections::HashSet<usize> = report
1001            .trials
1002            .iter()
1003            .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
1004            .map(|t| t.config.laplacian.k_neighbors)
1005            .collect();
1006        assert_eq!(lap_ks.len(), 2);
1007    }
1008
1009    #[test]
1010    fn laplacian_knobs_produce_distinct_configs() {
1011        // Sanity check that when Laplacian is the only kind, varying its
1012        // hyperparameters produces configs whose LaplacianConfig actually
1013        // differs (and doesn't accidentally alias on same-(k, threshold) pairs).
1014        let s = SearchSpace {
1015            projection_kinds: vec![ProjectionKind::LaplacianEigenmap],
1016            laplacian_k_neighbors: vec![10, 20],
1017            laplacian_active_threshold: vec![0.03, 0.08],
1018            num_domain_groups: vec![3],
1019            low_evr_threshold: vec![0.35],
1020            overlap_artifact_territorial: vec![0.3],
1021            threshold_base: vec![0.5],
1022            threshold_evr_penalty: vec![0.4],
1023            min_evr_improvement: vec![0.10],
1024        };
1025        let base = PipelineConfig::default();
1026        let configs: Vec<(usize, u64)> = (0..s.grid_cardinality())
1027            .map(|i| {
1028                let cfg = s.config_at_index(i, &base).unwrap();
1029                (
1030                    cfg.laplacian.k_neighbors,
1031                    cfg.laplacian.active_threshold.to_bits(),
1032                )
1033            })
1034            .collect();
1035        let unique: std::collections::HashSet<(usize, u64)> = configs.iter().copied().collect();
1036        assert_eq!(unique.len(), 4, "expected 4 distinct (k, threshold) pairs");
1037    }
1038
1039    #[test]
1040    fn bayesian_respects_budget() {
1041        let input = make_input(24, 8);
1042        let metric = TerritorialHealth;
1043        let (_p, report) = auto_tune(
1044            input,
1045            &SearchSpace::default(),
1046            &metric,
1047            SearchStrategy::Bayesian {
1048                budget: 10,
1049                warmup: 4,
1050                gamma: 0.25,
1051                seed: 42,
1052            },
1053            &PipelineConfig::default(),
1054        )
1055        .unwrap();
1056        assert_eq!(report.trials.len(), 10);
1057    }
1058
1059    #[test]
1060    fn bayesian_seed_reproducible() {
1061        let metric = TerritorialHealth;
1062        let run = |seed: u64| {
1063            let input = make_input(24, 8);
1064            auto_tune(
1065                input,
1066                &SearchSpace::default(),
1067                &metric,
1068                SearchStrategy::Bayesian {
1069                    budget: 8,
1070                    warmup: 3,
1071                    gamma: 0.25,
1072                    seed,
1073                },
1074                &PipelineConfig::default(),
1075            )
1076            .unwrap()
1077            .1
1078        };
1079        let a = run(7);
1080        let b = run(7);
1081        assert_eq!(a.trials.len(), b.trials.len());
1082        for (ta, tb) in a.trials.iter().zip(b.trials.iter()) {
1083            assert_eq!(ta.config.projection_kind, tb.config.projection_kind);
1084            assert!((ta.score - tb.score).abs() < 1e-12);
1085        }
1086    }
1087
1088    #[test]
1089    fn bayesian_finds_something_under_default_metric() {
1090        // Only asserting the tuner runs to completion and best_score is a
1091        // valid [0, 1] value — not that Bayesian strictly beats random at
1092        // this small budget (it often does, but not monotonically).
1093        let input = make_input(30, 10);
1094        let metric = CompositeMetric::default_composite();
1095        let (_p, report) = auto_tune(
1096            input,
1097            &SearchSpace::default(),
1098            &metric,
1099            SearchStrategy::Bayesian {
1100                budget: 12,
1101                warmup: 4,
1102                gamma: 0.25,
1103                seed: 0xC0FFEE,
1104            },
1105            &PipelineConfig::default(),
1106        )
1107        .unwrap();
1108        assert_eq!(report.trials.len(), 12);
1109        assert!(report.best_score >= 0.0 && report.best_score <= 1.0);
1110    }
1111
1112    #[test]
1113    fn bayesian_warmup_clamped() {
1114        // warmup = 100 with budget = 5 should clamp to 5 (all warmup).
1115        let input = make_input(24, 8);
1116        let metric = TerritorialHealth;
1117        let (_p, report) = auto_tune(
1118            input,
1119            &SearchSpace::default(),
1120            &metric,
1121            SearchStrategy::Bayesian {
1122                budget: 5,
1123                warmup: 100,
1124                gamma: 0.25,
1125                seed: 1,
1126            },
1127            &PipelineConfig::default(),
1128        )
1129        .unwrap();
1130        assert_eq!(report.trials.len(), 5);
1131    }
1132
1133    #[test]
1134    fn returned_pipeline_uses_best_config() {
1135        let input = make_input(24, 8);
1136        let metric = TerritorialHealth;
1137        let (pipeline, report) = auto_tune(
1138            input,
1139            &SearchSpace::default(),
1140            &metric,
1141            SearchStrategy::Random {
1142                budget: 4,
1143                seed: 11,
1144            },
1145            &PipelineConfig::default(),
1146        )
1147        .unwrap();
1148        assert_eq!(
1149            pipeline.config().routing.num_domain_groups,
1150            report.best_config.routing.num_domain_groups
1151        );
1152        assert_eq!(
1153            pipeline.projection_kind(),
1154            report.best_config.projection_kind
1155        );
1156    }
1157}