1use std::collections::HashMap;
17use std::time::Instant;
18
19use crate::config::{
20 BridgeConfig, InnerSphereConfig, LaplacianConfig, PipelineConfig, ProjectionKind, RoutingConfig,
21};
22use crate::configured_projection::ConfiguredProjection;
23use crate::pipeline::{PipelineError, PipelineInput, SphereQLPipeline, fit_projection_for_config};
24use crate::projection::SplitMix64;
25use crate::quality_metric::QualityMetric;
26use crate::types::Embedding;
27
/// Discrete hyperparameter search space explored by [`auto_tune`].
///
/// Each field lists the candidate values for one pipeline knob; every knob not
/// listed here is inherited unchanged from the caller's base configuration.
/// Grid search enumerates the Cartesian product of the lists (per projection
/// kind, see [`SearchSpace::grid_cardinality`]); random and Bayesian search
/// pick one value per knob.
///
/// An empty list gives every affected projection kind an empty grid slice and
/// makes random/Bayesian sampling panic (uniform picking divides by the list
/// length).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchSpace {
    /// Projection algorithms to try. Each kind owns its own slice of the grid;
    /// only `ProjectionKind::LaplacianEigenmap` additionally varies the
    /// `laplacian_*` knobs below.
    pub projection_kinds: Vec<ProjectionKind>,

    /// Candidate `LaplacianConfig::k_neighbors` values (Laplacian eigenmap only).
    pub laplacian_k_neighbors: Vec<usize>,
    /// Candidate `LaplacianConfig::active_threshold` values (Laplacian eigenmap only).
    pub laplacian_active_threshold: Vec<f64>,

    /// Candidate `RoutingConfig::num_domain_groups` values.
    pub num_domain_groups: Vec<usize>,
    /// Candidate `RoutingConfig::low_evr_threshold` values.
    pub low_evr_threshold: Vec<f64>,
    /// Candidate `BridgeConfig::overlap_artifact_territorial` values.
    pub overlap_artifact_territorial: Vec<f64>,
    /// Candidate `BridgeConfig::threshold_base` values.
    pub threshold_base: Vec<f64>,
    /// Candidate `BridgeConfig::threshold_evr_penalty` values.
    pub threshold_evr_penalty: Vec<f64>,
    /// Candidate `InnerSphereConfig::min_evr_improvement` values.
    pub min_evr_improvement: Vec<f64>,
}
74
75impl Default for SearchSpace {
76 fn default() -> Self {
77 Self {
78 projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
82 laplacian_k_neighbors: vec![10, 15, 25],
86 laplacian_active_threshold: vec![0.03, 0.05, 0.10],
87 num_domain_groups: vec![3, 5, 7],
88 low_evr_threshold: vec![0.25, 0.35, 0.45],
89 overlap_artifact_territorial: vec![0.2, 0.3, 0.4],
90 threshold_base: vec![0.4, 0.5, 0.6],
91 threshold_evr_penalty: vec![0.2, 0.4, 0.6],
92 min_evr_improvement: vec![0.05, 0.10, 0.15],
93 }
94 }
95}
96
97impl SearchSpace {
98 fn common_cardinality(&self) -> usize {
102 self.num_domain_groups.len()
103 * self.low_evr_threshold.len()
104 * self.overlap_artifact_territorial.len()
105 * self.threshold_base.len()
106 * self.threshold_evr_penalty.len()
107 * self.min_evr_improvement.len()
108 }
109
110 fn kind_cardinality(&self, kind: ProjectionKind) -> usize {
113 let common = self.common_cardinality();
114 match kind {
115 ProjectionKind::LaplacianEigenmap => {
116 common * self.laplacian_k_neighbors.len() * self.laplacian_active_threshold.len()
117 }
118 ProjectionKind::Pca | ProjectionKind::KernelPca => common,
119 }
120 }
121
122 pub fn grid_cardinality(&self) -> usize {
125 self.projection_kinds
126 .iter()
127 .map(|&k| self.kind_cardinality(k))
128 .sum()
129 }
130
131 pub fn config_at_index(&self, index: usize, base: &PipelineConfig) -> Option<PipelineConfig> {
139 let mut offset = 0usize;
140 for &kind in &self.projection_kinds {
141 let slice = self.kind_cardinality(kind);
142 if index < offset + slice {
143 return Some(self.config_at_kind_index(kind, index - offset, base));
144 }
145 offset += slice;
146 }
147 None
148 }
149
150 fn config_at_kind_index(
152 &self,
153 kind: ProjectionKind,
154 mut idx: usize,
155 base: &PipelineConfig,
156 ) -> PipelineConfig {
157 let take = |idx: &mut usize, len: usize| -> usize {
158 let v = *idx % len;
159 *idx /= len;
160 v
161 };
162
163 let i_ndg = take(&mut idx, self.num_domain_groups.len());
164 let i_let = take(&mut idx, self.low_evr_threshold.len());
165 let i_oat = take(&mut idx, self.overlap_artifact_territorial.len());
166 let i_tb = take(&mut idx, self.threshold_base.len());
167 let i_tep = take(&mut idx, self.threshold_evr_penalty.len());
168 let i_mei = take(&mut idx, self.min_evr_improvement.len());
169
170 let mut cfg = base.clone();
171 cfg.projection_kind = kind;
172 cfg.routing = RoutingConfig {
173 num_domain_groups: self.num_domain_groups[i_ndg],
174 low_evr_threshold: self.low_evr_threshold[i_let],
175 };
176 cfg.bridges = BridgeConfig {
177 threshold_base: self.threshold_base[i_tb],
178 threshold_evr_penalty: self.threshold_evr_penalty[i_tep],
179 overlap_artifact_territorial: self.overlap_artifact_territorial[i_oat],
180 };
181 cfg.inner_sphere = InnerSphereConfig {
182 min_evr_improvement: self.min_evr_improvement[i_mei],
183 ..base.inner_sphere.clone()
184 };
185
186 if matches!(kind, ProjectionKind::LaplacianEigenmap) {
187 let i_k = take(&mut idx, self.laplacian_k_neighbors.len());
188 let i_thr = take(&mut idx, self.laplacian_active_threshold.len());
189 cfg.laplacian = LaplacianConfig {
190 k_neighbors: self.laplacian_k_neighbors[i_k],
191 active_threshold: self.laplacian_active_threshold[i_thr],
192 };
193 }
194
195 cfg
196 }
197
198 pub(crate) fn sample(&self, rng: &mut SplitMix64, base: &PipelineConfig) -> PipelineConfig {
204 let mut cfg = base.clone();
205 cfg.projection_kind = pick_uniform(rng, &self.projection_kinds);
206 cfg.routing = RoutingConfig {
207 num_domain_groups: pick_uniform(rng, &self.num_domain_groups),
208 low_evr_threshold: pick_uniform(rng, &self.low_evr_threshold),
209 };
210 cfg.bridges = BridgeConfig {
211 threshold_base: pick_uniform(rng, &self.threshold_base),
212 threshold_evr_penalty: pick_uniform(rng, &self.threshold_evr_penalty),
213 overlap_artifact_territorial: pick_uniform(rng, &self.overlap_artifact_territorial),
214 };
215 cfg.inner_sphere = InnerSphereConfig {
216 min_evr_improvement: pick_uniform(rng, &self.min_evr_improvement),
217 ..base.inner_sphere.clone()
218 };
219
220 if matches!(cfg.projection_kind, ProjectionKind::LaplacianEigenmap) {
221 cfg.laplacian = LaplacianConfig {
222 k_neighbors: pick_uniform(rng, &self.laplacian_k_neighbors),
223 active_threshold: pick_uniform(rng, &self.laplacian_active_threshold),
224 };
225 }
226
227 cfg
228 }
229}
230
231#[derive(Clone, PartialEq, Eq, Hash)]
241enum ProjectionFitKey {
242 Pca,
243 KernelPca,
244 Laplacian { k: usize, threshold_bits: u64 },
245}
246
247impl ProjectionFitKey {
248 fn from_config(cfg: &PipelineConfig) -> Self {
249 match cfg.projection_kind {
250 ProjectionKind::Pca => Self::Pca,
251 ProjectionKind::KernelPca => Self::KernelPca,
252 ProjectionKind::LaplacianEigenmap => Self::Laplacian {
253 k: cfg.laplacian.k_neighbors,
254 threshold_bits: cfg.laplacian.active_threshold.to_bits(),
255 },
256 }
257 }
258}
259
/// How [`auto_tune`] chooses which configurations to evaluate.
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// Evaluate every point of the search space, in
    /// [`SearchSpace::config_at_index`] order.
    Grid,
    /// Evaluate `budget` configurations drawn independently and uniformly
    /// (with replacement, so duplicates are possible) from the space, using an
    /// RNG seeded with `seed`.
    Random { budget: usize, seed: u64 },
    /// Sequential search in the style of a Tree-structured Parzen Estimator:
    /// after `warmup` uniformly random trials, each proposal favours knob
    /// values that are more frequent among the best-scoring trials than among
    /// the rest.
    Bayesian {
        /// Total number of configurations evaluated (successful or not),
        /// warmup included.
        budget: usize,
        /// Number of initial uniformly random trials; raised to at least 2 and
        /// capped at `budget`.
        warmup: usize,
        /// Fraction of completed trials (best first) treated as "good" when
        /// proposing; clamped to `[0.05, 0.95]`.
        gamma: f64,
        /// Seed for the RNG that drives both warmup sampling and proposals.
        seed: u64,
    },
}
292
/// One successfully built and scored configuration.
#[derive(Debug, Clone)]
pub struct TrialRecord {
    /// The configuration that was evaluated.
    pub config: PipelineConfig,
    /// Value the tuning metric assigned to this configuration (higher is better).
    pub score: f64,
    /// Wall-clock milliseconds spent building the pipeline and scoring it;
    /// projection fitting happens before the timer starts and is excluded.
    pub build_ms: u128,
}
302
/// Outcome of an [`auto_tune`] run.
#[derive(Debug, Clone)]
pub struct TuneReport {
    /// `name()` of the metric that was maximized.
    pub metric_name: String,
    /// Score of the winning trial.
    pub best_score: f64,
    /// Configuration of the winning trial; the pipeline returned by
    /// [`auto_tune`] is rebuilt from it.
    pub best_config: PipelineConfig,
    /// Every successful trial, in evaluation order.
    pub trials: Vec<TrialRecord>,
    /// Configurations that failed to fit or build, paired with the error text.
    pub failures: Vec<(PipelineConfig, String)>,
}
315
316impl TuneReport {
317 pub fn ranked_trials(&self) -> Vec<&TrialRecord> {
319 let mut refs: Vec<&TrialRecord> = self.trials.iter().collect();
320 refs.sort_by(|a, b| {
321 b.score
322 .partial_cmp(&a.score)
323 .unwrap_or(std::cmp::Ordering::Equal)
324 });
325 refs
326 }
327
328 pub fn mean_score(&self) -> f64 {
332 if self.trials.is_empty() {
333 return 0.0;
334 }
335 self.trials.iter().map(|t| t.score).sum::<f64>() / self.trials.len() as f64
336 }
337}
338
339pub fn auto_tune<M: QualityMetric + ?Sized>(
351 input: PipelineInput,
352 space: &SearchSpace,
353 metric: &M,
354 strategy: SearchStrategy,
355 base_config: &PipelineConfig,
356) -> Result<(SphereQLPipeline, TuneReport), PipelineError> {
357 let categories = input.categories;
360 let embeddings: Vec<Embedding> = input.embeddings.into_iter().map(Embedding::new).collect();
361
362 let mut prefit: HashMap<ProjectionFitKey, ConfiguredProjection> = HashMap::new();
363 let mut trials: Vec<TrialRecord> = Vec::new();
364 let mut failures: Vec<(PipelineConfig, String)> = Vec::new();
365
366 let run_trial = |cfg: PipelineConfig,
370 prefit: &mut HashMap<ProjectionFitKey, ConfiguredProjection>,
371 trials: &mut Vec<TrialRecord>,
372 failures: &mut Vec<(PipelineConfig, String)>| {
373 let key = ProjectionFitKey::from_config(&cfg);
374 let projection = match prefit.get(&key) {
375 Some(p) => p.clone(),
376 None => match fit_projection_for_config(&embeddings, &cfg) {
377 Ok(p) => {
378 prefit.insert(key, p.clone());
379 p
380 }
381 Err(e) => {
382 failures.push((cfg, e.to_string()));
383 return;
384 }
385 },
386 };
387
388 let start = Instant::now();
389 match SphereQLPipeline::with_configured_projection_and_config(
390 categories.clone(),
391 embeddings.clone(),
392 projection,
393 cfg.clone(),
394 ) {
395 Ok(pipeline) => {
396 let score = metric.score(&pipeline);
397 let build_ms = start.elapsed().as_millis();
398 trials.push(TrialRecord {
399 config: cfg,
400 score,
401 build_ms,
402 });
403 }
404 Err(e) => {
405 failures.push((cfg, e.to_string()));
406 }
407 }
408 };
409
410 match &strategy {
411 SearchStrategy::Grid => {
412 for i in 0..space.grid_cardinality() {
413 if let Some(cfg) = space.config_at_index(i, base_config) {
414 run_trial(cfg, &mut prefit, &mut trials, &mut failures);
415 }
416 }
417 }
418 SearchStrategy::Random { budget, seed } => {
419 let mut rng = SplitMix64::new(*seed);
420 for _ in 0..*budget {
421 let cfg = space.sample(&mut rng, base_config);
422 run_trial(cfg, &mut prefit, &mut trials, &mut failures);
423 }
424 }
425 SearchStrategy::Bayesian {
426 budget,
427 warmup,
428 gamma,
429 seed,
430 } => {
431 let mut rng = SplitMix64::new(*seed);
432 let budget = *budget;
433 let warmup = (*warmup).clamp(2, budget);
434 let gamma = gamma.clamp(0.05, 0.95);
435
436 for _ in 0..warmup {
438 let cfg = space.sample(&mut rng, base_config);
439 run_trial(cfg, &mut prefit, &mut trials, &mut failures);
440 }
441 for _ in warmup..budget {
443 let cfg = tpe_propose(space, base_config, &trials, gamma, &mut rng);
444 run_trial(cfg, &mut prefit, &mut trials, &mut failures);
445 }
446 }
447 }
448
449 if trials.is_empty() {
450 return Err(PipelineError::AllTrialsFailed { failures });
454 }
455
456 let best_idx = trials
458 .iter()
459 .enumerate()
460 .max_by(|(_, a), (_, b)| {
461 a.score
462 .partial_cmp(&b.score)
463 .unwrap_or(std::cmp::Ordering::Equal)
464 })
465 .map(|(i, _)| i)
466 .expect("trials non-empty");
467 let best_config = trials[best_idx].config.clone();
468 let best_score = trials[best_idx].score;
469
470 let best_key = ProjectionFitKey::from_config(&best_config);
476 let best_projection = match prefit.get(&best_key).cloned() {
477 Some(p) => p,
478 None => fit_projection_for_config(&embeddings, &best_config)?,
479 };
480 let best_pipeline = SphereQLPipeline::with_configured_projection_and_config(
481 categories,
482 embeddings,
483 best_projection,
484 best_config.clone(),
485 )?;
486
487 let report = TuneReport {
488 metric_name: metric.name().to_string(),
489 best_score,
490 best_config,
491 trials,
492 failures,
493 };
494
495 Ok((best_pipeline, report))
496}
497
498fn tpe_propose(
513 space: &SearchSpace,
514 base: &PipelineConfig,
515 trials: &[TrialRecord],
516 gamma: f64,
517 rng: &mut SplitMix64,
518) -> PipelineConfig {
519 let mut sorted: Vec<&TrialRecord> = trials.iter().collect();
521 sorted.sort_by(|a, b| {
522 b.score
523 .partial_cmp(&a.score)
524 .unwrap_or(std::cmp::Ordering::Equal)
525 });
526 let n_good = ((sorted.len() as f64) * gamma).ceil() as usize;
527 let n_good = n_good.max(1).min(sorted.len().saturating_sub(1).max(1));
528 let good: Vec<&TrialRecord> = sorted.iter().take(n_good).copied().collect();
529 let bad: Vec<&TrialRecord> = sorted.iter().skip(n_good).copied().collect();
530
531 if good.is_empty() || bad.is_empty() {
533 return space.sample(rng, base);
534 }
535
536 let pick_idx = |rng: &mut SplitMix64, good_counts: &[f64], bad_counts: &[f64]| -> usize {
537 let n_g = good_counts.iter().sum::<f64>() + good_counts.len() as f64;
538 let n_b = bad_counts.iter().sum::<f64>() + bad_counts.len() as f64;
539 let weights: Vec<f64> = good_counts
540 .iter()
541 .zip(bad_counts.iter())
542 .map(|(&g, &b)| ((g + 1.0) / n_g) / ((b + 1.0) / n_b))
543 .collect();
544 sample_categorical(rng, &weights)
545 };
546
547 let pk_g = hist_kind(&good, &space.projection_kinds);
549 let pk_b = hist_kind(&bad, &space.projection_kinds);
550 let kind = space.projection_kinds[pick_idx(rng, &pk_g, &pk_b)];
551
552 let ndg_g = hist_usize(&good, &space.num_domain_groups, |c| {
554 c.routing.num_domain_groups
555 });
556 let ndg_b = hist_usize(&bad, &space.num_domain_groups, |c| {
557 c.routing.num_domain_groups
558 });
559 let let_g = hist_f64(&good, &space.low_evr_threshold, |c| {
560 c.routing.low_evr_threshold
561 });
562 let let_b = hist_f64(&bad, &space.low_evr_threshold, |c| {
563 c.routing.low_evr_threshold
564 });
565 let oat_g = hist_f64(&good, &space.overlap_artifact_territorial, |c| {
566 c.bridges.overlap_artifact_territorial
567 });
568 let oat_b = hist_f64(&bad, &space.overlap_artifact_territorial, |c| {
569 c.bridges.overlap_artifact_territorial
570 });
571 let tb_g = hist_f64(&good, &space.threshold_base, |c| c.bridges.threshold_base);
572 let tb_b = hist_f64(&bad, &space.threshold_base, |c| c.bridges.threshold_base);
573 let tep_g = hist_f64(&good, &space.threshold_evr_penalty, |c| {
574 c.bridges.threshold_evr_penalty
575 });
576 let tep_b = hist_f64(&bad, &space.threshold_evr_penalty, |c| {
577 c.bridges.threshold_evr_penalty
578 });
579 let mei_g = hist_f64(&good, &space.min_evr_improvement, |c| {
580 c.inner_sphere.min_evr_improvement
581 });
582 let mei_b = hist_f64(&bad, &space.min_evr_improvement, |c| {
583 c.inner_sphere.min_evr_improvement
584 });
585
586 let mut cfg = base.clone();
587 cfg.projection_kind = kind;
588 cfg.routing = RoutingConfig {
589 num_domain_groups: space.num_domain_groups[pick_idx(rng, &ndg_g, &ndg_b)],
590 low_evr_threshold: space.low_evr_threshold[pick_idx(rng, &let_g, &let_b)],
591 };
592 cfg.bridges = BridgeConfig {
593 threshold_base: space.threshold_base[pick_idx(rng, &tb_g, &tb_b)],
594 threshold_evr_penalty: space.threshold_evr_penalty[pick_idx(rng, &tep_g, &tep_b)],
595 overlap_artifact_territorial: space.overlap_artifact_territorial
596 [pick_idx(rng, &oat_g, &oat_b)],
597 };
598 cfg.inner_sphere = InnerSphereConfig {
599 min_evr_improvement: space.min_evr_improvement[pick_idx(rng, &mei_g, &mei_b)],
600 ..base.inner_sphere.clone()
601 };
602
603 if matches!(kind, ProjectionKind::LaplacianEigenmap) {
605 let good_l: Vec<&TrialRecord> = good
606 .iter()
607 .copied()
608 .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
609 .collect();
610 let bad_l: Vec<&TrialRecord> = bad
611 .iter()
612 .copied()
613 .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
614 .collect();
615 if good_l.is_empty() || bad_l.is_empty() {
616 cfg.laplacian = LaplacianConfig {
618 k_neighbors: space.laplacian_k_neighbors
619 [(rng.next_u64() as usize) % space.laplacian_k_neighbors.len()],
620 active_threshold: space.laplacian_active_threshold
621 [(rng.next_u64() as usize) % space.laplacian_active_threshold.len()],
622 };
623 } else {
624 let k_g = hist_usize(&good_l, &space.laplacian_k_neighbors, |c| {
625 c.laplacian.k_neighbors
626 });
627 let k_b = hist_usize(&bad_l, &space.laplacian_k_neighbors, |c| {
628 c.laplacian.k_neighbors
629 });
630 let at_g = hist_f64(&good_l, &space.laplacian_active_threshold, |c| {
631 c.laplacian.active_threshold
632 });
633 let at_b = hist_f64(&bad_l, &space.laplacian_active_threshold, |c| {
634 c.laplacian.active_threshold
635 });
636 cfg.laplacian = LaplacianConfig {
637 k_neighbors: space.laplacian_k_neighbors[pick_idx(rng, &k_g, &k_b)],
638 active_threshold: space.laplacian_active_threshold[pick_idx(rng, &at_g, &at_b)],
639 };
640 }
641 }
642
643 cfg
644}
645
646fn hist_kind(trials: &[&TrialRecord], values: &[ProjectionKind]) -> Vec<f64> {
647 let mut counts = vec![0.0f64; values.len()];
648 for t in trials {
649 if let Some(i) = values.iter().position(|&v| v == t.config.projection_kind) {
650 counts[i] += 1.0;
651 }
652 }
653 counts
654}
655
656fn hist_usize(
657 trials: &[&TrialRecord],
658 values: &[usize],
659 extract: impl Fn(&PipelineConfig) -> usize,
660) -> Vec<f64> {
661 let mut counts = vec![0.0f64; values.len()];
662 for t in trials {
663 let v = extract(&t.config);
664 if let Some(i) = values.iter().position(|&x| x == v) {
665 counts[i] += 1.0;
666 }
667 }
668 counts
669}
670
671fn hist_f64(
676 trials: &[&TrialRecord],
677 values: &[f64],
678 extract: impl Fn(&PipelineConfig) -> f64,
679) -> Vec<f64> {
680 let mut counts = vec![0.0f64; values.len()];
681 for t in trials {
682 let v = extract(&t.config);
683 if let Some((i, _)) = values.iter().enumerate().min_by(|a, b| {
684 (a.1 - v)
685 .abs()
686 .partial_cmp(&(b.1 - v).abs())
687 .unwrap_or(std::cmp::Ordering::Equal)
688 }) {
689 counts[i] += 1.0;
690 }
691 }
692 counts
693}
694
695fn pick_uniform<T: Copy>(rng: &mut SplitMix64, vals: &[T]) -> T {
700 vals[(rng.next_u64() as usize) % vals.len()]
701}
702
703fn sample_categorical(rng: &mut SplitMix64, weights: &[f64]) -> usize {
704 let total: f64 = weights.iter().sum();
705 if total <= 0.0 || !total.is_finite() {
706 return (rng.next_u64() as usize) % weights.len().max(1);
707 }
708 let r = rng.next_f64() * total;
709 let mut acc = 0.0;
710 for (i, &w) in weights.iter().enumerate() {
711 acc += w;
712 if r <= acc {
713 return i;
714 }
715 }
716 weights.len() - 1
717}
718
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quality_metric::{BridgeCoherence, CompositeMetric, TerritorialHealth};

    /// Builds `n` synthetic embeddings of dimension `dim` in three labelled
    /// clusters ("one", "two", "three"). Each cluster is dominated by a
    /// different pair of axes, and axis 6 carries a ramp shared by all points.
    ///
    /// Writes axes 0..=6 directly, so `dim` must be at least 7.
    fn make_input(n: usize, dim: usize) -> PipelineInput {
        let mut embeddings = Vec::new();
        let mut categories = Vec::new();
        for i in 0..n {
            let mut v = vec![0.0; dim];
            if i < n / 3 {
                v[0] = 1.0 + (i as f64 * 0.01);
                v[1] = 0.1;
                categories.push("one".into());
            } else if i < 2 * n / 3 {
                v[2] = 1.0 + (i as f64 * 0.01);
                v[3] = 0.1;
                categories.push("two".into());
            } else {
                v[4] = 1.0 + (i as f64 * 0.01);
                v[5] = 0.1;
                categories.push("three".into());
            }
            v[6] = 0.02 * i as f64;
            embeddings.push(v);
        }
        PipelineInput {
            categories,
            embeddings,
        }
    }

    #[test]
    fn search_space_grid_cardinality_sums_per_kind() {
        let s = SearchSpace::default();
        let common = s.num_domain_groups.len()
            * s.low_evr_threshold.len()
            * s.overlap_artifact_territorial.len()
            * s.threshold_base.len()
            * s.threshold_evr_penalty.len()
            * s.min_evr_improvement.len();
        // Default kinds: PCA contributes the shared knobs only; Laplacian
        // multiplies them by its two extra candidate lists.
        let expected =
            common + common * s.laplacian_k_neighbors.len() * s.laplacian_active_threshold.len();
        assert_eq!(s.grid_cardinality(), expected);
    }

    #[test]
    fn default_search_space_includes_pca_and_laplacian() {
        let s = SearchSpace::default();
        assert!(s.projection_kinds.contains(&ProjectionKind::Pca));
        assert!(
            s.projection_kinds
                .contains(&ProjectionKind::LaplacianEigenmap)
        );
        // The default space must not include kernel PCA.
        assert!(!s.projection_kinds.contains(&ProjectionKind::KernelPca));
    }

    #[test]
    fn grid_index_enumerates_full_space() {
        // Only num_domain_groups and low_evr_threshold vary: 2 x 2 = 4 points.
        let s = SearchSpace {
            projection_kinds: vec![ProjectionKind::Pca],
            laplacian_k_neighbors: vec![15],
            laplacian_active_threshold: vec![0.05],
            num_domain_groups: vec![3, 5],
            low_evr_threshold: vec![0.3, 0.4],
            overlap_artifact_territorial: vec![0.3],
            threshold_base: vec![0.5],
            threshold_evr_penalty: vec![0.4],
            min_evr_improvement: vec![0.10],
        };
        let base = PipelineConfig::default();
        let n = s.grid_cardinality();
        let mut seen = std::collections::HashSet::new();
        for i in 0..n {
            let cfg = s.config_at_index(i, &base).unwrap();
            // Quantize the f64 knob so the pair is hashable.
            let key = (
                cfg.routing.num_domain_groups,
                (cfg.routing.low_evr_threshold * 1000.0) as i64,
            );
            seen.insert(key);
        }
        // Every index decodes to a distinct point, and the first index past
        // the end is rejected.
        assert_eq!(seen.len(), n);
        assert!(s.config_at_index(n, &base).is_none());
    }

    #[test]
    fn grid_index_enumerates_across_projection_kinds() {
        // One point per kind: every other list has a single candidate.
        let s = SearchSpace {
            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
            laplacian_k_neighbors: vec![15],
            laplacian_active_threshold: vec![0.05],
            num_domain_groups: vec![3],
            low_evr_threshold: vec![0.35],
            overlap_artifact_territorial: vec![0.3],
            threshold_base: vec![0.5],
            threshold_evr_penalty: vec![0.4],
            min_evr_improvement: vec![0.10],
        };
        let base = PipelineConfig::default();
        let kinds: std::collections::HashSet<ProjectionKind> = (0..s.grid_cardinality())
            .map(|i| s.config_at_index(i, &base).unwrap().projection_kind)
            .collect();
        assert_eq!(kinds.len(), 2);
        assert!(kinds.contains(&ProjectionKind::Pca));
        assert!(kinds.contains(&ProjectionKind::LaplacianEigenmap));
    }

    #[test]
    fn grid_search_runs_and_picks_best() {
        let input = make_input(24, 8);
        // Two grid points, differing only in num_domain_groups.
        let space = SearchSpace {
            projection_kinds: vec![ProjectionKind::Pca],
            laplacian_k_neighbors: vec![15],
            laplacian_active_threshold: vec![0.05],
            num_domain_groups: vec![3, 5],
            low_evr_threshold: vec![0.35],
            overlap_artifact_territorial: vec![0.3],
            threshold_base: vec![0.5],
            threshold_evr_penalty: vec![0.4],
            min_evr_improvement: vec![0.10],
        };
        let metric = TerritorialHealth;
        let (pipeline, report) = auto_tune(
            input,
            &space,
            &metric,
            SearchStrategy::Grid,
            &PipelineConfig::default(),
        )
        .unwrap();

        assert_eq!(report.trials.len(), 2);
        // The best score can never be below the mean (small float slack).
        assert!(report.best_score >= report.mean_score() - 1e-9);
        assert!(pipeline.num_categories() > 0);
        assert_eq!(report.metric_name, "territorial_health");
        assert!(report.failures.is_empty());
    }

    #[test]
    fn random_search_respects_budget() {
        let input = make_input(24, 8);
        let space = SearchSpace::default();
        let metric = BridgeCoherence;
        let (_pipeline, report) = auto_tune(
            input,
            &space,
            &metric,
            SearchStrategy::Random {
                budget: 5,
                seed: 42,
            },
            &PipelineConfig::default(),
        )
        .unwrap();
        assert_eq!(report.trials.len(), 5);
    }

    #[test]
    fn random_search_is_seed_reproducible() {
        let space = SearchSpace::default();
        let metric = TerritorialHealth;

        let run = |seed: u64| {
            let input = make_input(24, 8);
            auto_tune(
                input,
                &space,
                &metric,
                SearchStrategy::Random { budget: 8, seed },
                &PipelineConfig::default(),
            )
            .unwrap()
            .1
        };

        let a = run(7);
        let b = run(7);
        let c = run(13);

        // Same seed: identical configs and scores, trial by trial.
        assert_eq!(a.trials.len(), b.trials.len());
        for (ta, tb) in a.trials.iter().zip(b.trials.iter()) {
            assert_eq!(
                ta.config.routing.num_domain_groups,
                tb.config.routing.num_domain_groups
            );
            assert!((ta.score - tb.score).abs() < 1e-12);
        }
        // Different seed: at least one trial must differ in a sampled knob.
        let any_differ = a.trials.iter().zip(c.trials.iter()).any(|(ta, tc)| {
            ta.config.routing.num_domain_groups != tc.config.routing.num_domain_groups
                || (ta.config.bridges.threshold_base - tc.config.bridges.threshold_base).abs()
                    > 1e-12
        });
        assert!(any_differ, "different seeds produced identical trial set");
    }

    #[test]
    fn ranked_trials_are_descending() {
        let input = make_input(24, 8);
        let metric = CompositeMetric::default_composite();
        let (_p, report) = auto_tune(
            input,
            &SearchSpace::default(),
            &metric,
            SearchStrategy::Random {
                budget: 6,
                seed: 99,
            },
            &PipelineConfig::default(),
        )
        .unwrap();
        let ranked = report.ranked_trials();
        for w in ranked.windows(2) {
            assert!(w[0].score >= w[1].score);
        }
    }

    #[test]
    fn best_config_actually_in_trials() {
        let input = make_input(24, 8);
        let metric = TerritorialHealth;
        let (_p, report) = auto_tune(
            input,
            &SearchSpace::default(),
            &metric,
            SearchStrategy::Random { budget: 4, seed: 1 },
            &PipelineConfig::default(),
        )
        .unwrap();
        // Match on a couple of routing knobs plus the score.
        let any_match = report.trials.iter().any(|t| {
            t.config.routing.num_domain_groups == report.best_config.routing.num_domain_groups
                && (t.config.routing.low_evr_threshold
                    - report.best_config.routing.low_evr_threshold)
                    .abs()
                    < 1e-12
                && (t.score - report.best_score).abs() < 1e-12
        });
        assert!(any_match, "best_config must appear in trials");
    }

    #[test]
    fn grid_search_across_projection_kinds_yields_both() {
        let input = make_input(24, 8);
        let space = SearchSpace {
            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
            laplacian_k_neighbors: vec![10, 20],
            laplacian_active_threshold: vec![0.05],
            num_domain_groups: vec![3],
            low_evr_threshold: vec![0.35],
            overlap_artifact_territorial: vec![0.3],
            threshold_base: vec![0.5],
            threshold_evr_penalty: vec![0.4],
            min_evr_improvement: vec![0.10],
        };
        let metric = TerritorialHealth;
        let (_pipeline, report) = auto_tune(
            input,
            &space,
            &metric,
            SearchStrategy::Grid,
            &PipelineConfig::default(),
        )
        .unwrap();
        // 1 PCA point + 2 Laplacian points (two k_neighbors candidates).
        assert_eq!(report.trials.len(), 3);
        let kinds_in_trials: std::collections::HashSet<ProjectionKind> = report
            .trials
            .iter()
            .map(|t| t.config.projection_kind)
            .collect();
        assert!(kinds_in_trials.contains(&ProjectionKind::Pca));
        assert!(kinds_in_trials.contains(&ProjectionKind::LaplacianEigenmap));
        // Both Laplacian k values must actually have been evaluated.
        let lap_ks: std::collections::HashSet<usize> = report
            .trials
            .iter()
            .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
            .map(|t| t.config.laplacian.k_neighbors)
            .collect();
        assert_eq!(lap_ks.len(), 2);
    }

    #[test]
    fn laplacian_knobs_produce_distinct_configs() {
        // Laplacian-only space where just the two Laplacian knobs vary:
        // 2 x 2 = 4 grid points.
        let s = SearchSpace {
            projection_kinds: vec![ProjectionKind::LaplacianEigenmap],
            laplacian_k_neighbors: vec![10, 20],
            laplacian_active_threshold: vec![0.03, 0.08],
            num_domain_groups: vec![3],
            low_evr_threshold: vec![0.35],
            overlap_artifact_territorial: vec![0.3],
            threshold_base: vec![0.5],
            threshold_evr_penalty: vec![0.4],
            min_evr_improvement: vec![0.10],
        };
        let base = PipelineConfig::default();
        let configs: Vec<(usize, u64)> = (0..s.grid_cardinality())
            .map(|i| {
                let cfg = s.config_at_index(i, &base).unwrap();
                (
                    cfg.laplacian.k_neighbors,
                    cfg.laplacian.active_threshold.to_bits(),
                )
            })
            .collect();
        let unique: std::collections::HashSet<(usize, u64)> = configs.iter().copied().collect();
        assert_eq!(unique.len(), 4, "expected 4 distinct (k, threshold) pairs");
    }

    #[test]
    fn bayesian_respects_budget() {
        let input = make_input(24, 8);
        let metric = TerritorialHealth;
        let (_p, report) = auto_tune(
            input,
            &SearchSpace::default(),
            &metric,
            SearchStrategy::Bayesian {
                budget: 10,
                warmup: 4,
                gamma: 0.25,
                seed: 42,
            },
            &PipelineConfig::default(),
        )
        .unwrap();
        assert_eq!(report.trials.len(), 10);
    }

    #[test]
    fn bayesian_seed_reproducible() {
        let metric = TerritorialHealth;
        let run = |seed: u64| {
            let input = make_input(24, 8);
            auto_tune(
                input,
                &SearchSpace::default(),
                &metric,
                SearchStrategy::Bayesian {
                    budget: 8,
                    warmup: 3,
                    gamma: 0.25,
                    seed,
                },
                &PipelineConfig::default(),
            )
            .unwrap()
            .1
        };
        let a = run(7);
        let b = run(7);
        assert_eq!(a.trials.len(), b.trials.len());
        for (ta, tb) in a.trials.iter().zip(b.trials.iter()) {
            assert_eq!(ta.config.projection_kind, tb.config.projection_kind);
            assert!((ta.score - tb.score).abs() < 1e-12);
        }
    }

    #[test]
    fn bayesian_finds_something_under_default_metric() {
        let input = make_input(30, 10);
        let metric = CompositeMetric::default_composite();
        let (_p, report) = auto_tune(
            input,
            &SearchSpace::default(),
            &metric,
            SearchStrategy::Bayesian {
                budget: 12,
                warmup: 4,
                gamma: 0.25,
                seed: 0xC0FFEE,
            },
            &PipelineConfig::default(),
        )
        .unwrap();
        assert_eq!(report.trials.len(), 12);
        // Presumes the composite metric is normalized to [0, 1]; confirm
        // against CompositeMetric if this bound changes.
        assert!(report.best_score >= 0.0 && report.best_score <= 1.0);
    }

    #[test]
    fn bayesian_warmup_clamped() {
        let input = make_input(24, 8);
        // A warmup larger than the budget is capped at the budget.
        let metric = TerritorialHealth;
        let (_p, report) = auto_tune(
            input,
            &SearchSpace::default(),
            &metric,
            SearchStrategy::Bayesian {
                budget: 5,
                warmup: 100,
                gamma: 0.25,
                seed: 1,
            },
            &PipelineConfig::default(),
        )
        .unwrap();
        assert_eq!(report.trials.len(), 5);
    }

    #[test]
    fn returned_pipeline_uses_best_config() {
        let input = make_input(24, 8);
        let metric = TerritorialHealth;
        let (pipeline, report) = auto_tune(
            input,
            &SearchSpace::default(),
            &metric,
            SearchStrategy::Random {
                budget: 4,
                seed: 11,
            },
            &PipelineConfig::default(),
        )
        .unwrap();
        // The returned pipeline must be built from the reported best config.
        assert_eq!(
            pipeline.config().routing.num_domain_groups,
            report.best_config.routing.num_domain_groups
        );
        assert_eq!(
            pipeline.projection_kind(),
            report.best_config.projection_kind
        );
    }
}