1use std::collections::HashMap;
23use std::time::Instant;
24
25use crate::config::{
26 BridgeConfig, InnerSphereConfig, LaplacianConfig, PipelineConfig, ProjectionKind,
27 RoutingConfig, UmapConfig,
28};
29use crate::configured_projection::ConfiguredProjection;
30use crate::pipeline::{
31 PipelineError, PipelineInput, SphereQLPipeline, fit_projection_for_config, fit_umap_from_graph,
32};
33use crate::projection::SplitMix64;
34use crate::quality_metric::QualityMetric;
35use crate::types::Embedding;
36
37#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
49pub struct SearchSpace {
50 pub projection_kinds: Vec<ProjectionKind>,
54
55 pub laplacian_k_neighbors: Vec<usize>,
64 pub laplacian_active_threshold: Vec<f64>,
68
69 pub umap_n_neighbors: Vec<usize>,
72 pub umap_n_epochs: Vec<usize>,
75 pub umap_category_weight: Vec<f64>,
79 pub umap_min_dist: Vec<f64>,
82
83 pub num_domain_groups: Vec<usize>,
86 pub low_evr_threshold: Vec<f64>,
88 pub overlap_artifact_territorial: Vec<f64>,
90 pub threshold_base: Vec<f64>,
92 pub threshold_evr_penalty: Vec<f64>,
94 pub min_evr_improvement: Vec<f64>,
96}
97
98impl SearchSpace {
99 pub fn large_corpus() -> Self {
105 Self {
106 projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::UmapSphere],
107 laplacian_k_neighbors: vec![15],
111 laplacian_active_threshold: vec![0.05],
112 umap_n_neighbors: vec![10, 15, 30],
113 umap_n_epochs: vec![150, 300],
114 umap_category_weight: vec![0.0, 1.5, 3.0],
115 umap_min_dist: vec![0.0, 0.1, 0.25],
116 num_domain_groups: vec![3, 5, 7],
117 low_evr_threshold: vec![0.25, 0.35],
118 overlap_artifact_territorial: vec![0.2, 0.3],
119 threshold_base: vec![0.4, 0.5],
120 threshold_evr_penalty: vec![0.3, 0.5],
121 min_evr_improvement: vec![0.05, 0.10],
122 }
123 }
124}
125
126impl Default for SearchSpace {
127 fn default() -> Self {
128 Self {
129 projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
133 laplacian_k_neighbors: vec![10, 15, 25],
137 laplacian_active_threshold: vec![0.03, 0.05, 0.10],
138 umap_n_neighbors: vec![10, 15, 30],
139 umap_n_epochs: vec![150, 250],
140 umap_category_weight: vec![0.0, 1.5, 3.0],
141 umap_min_dist: vec![0.0, 0.1, 0.25],
142 num_domain_groups: vec![3, 5, 7],
143 low_evr_threshold: vec![0.25, 0.35, 0.45],
144 overlap_artifact_territorial: vec![0.2, 0.3, 0.4],
145 threshold_base: vec![0.4, 0.5, 0.6],
146 threshold_evr_penalty: vec![0.2, 0.4, 0.6],
147 min_evr_improvement: vec![0.05, 0.10, 0.15],
148 }
149 }
150}
151
152impl SearchSpace {
153 pub fn validate(&self, strategy: &SearchStrategy) -> Result<(), PipelineError> {
161 if self.projection_kinds.is_empty() {
162 return Err(PipelineError::InvalidSearchSpace(
163 "axis `projection_kinds` is empty".into(),
164 ));
165 }
166 for &kind in &self.projection_kinds {
167 self.check_axes_non_empty(kind)?;
168 }
169 match strategy {
170 SearchStrategy::Grid => {}
171 SearchStrategy::Random { budget, .. } => {
172 if *budget == 0 {
173 return Err(PipelineError::InvalidSearchSpace(
174 "Random search requires budget >= 1".into(),
175 ));
176 }
177 }
178 SearchStrategy::Bayesian {
179 budget,
180 warmup,
181 gamma,
182 ..
183 } => {
184 if *budget < 2 {
185 return Err(PipelineError::InvalidSearchSpace(format!(
186 "Bayesian search requires budget >= 2 (got {budget})"
187 )));
188 }
189 if *warmup < 2 {
190 return Err(PipelineError::InvalidSearchSpace(format!(
191 "Bayesian search requires warmup >= 2 (got {warmup})"
192 )));
193 }
194 if !gamma.is_finite() || *gamma <= 0.0 || *gamma >= 1.0 {
195 return Err(PipelineError::InvalidSearchSpace(format!(
196 "Bayesian gamma must be in (0, 1), got {gamma}"
197 )));
198 }
199 }
200 }
201 Ok(())
202 }
203
204 fn check_axes_non_empty(&self, kind: ProjectionKind) -> Result<(), PipelineError> {
205 let common = [
206 ("num_domain_groups", self.num_domain_groups.len()),
207 ("low_evr_threshold", self.low_evr_threshold.len()),
208 (
209 "overlap_artifact_territorial",
210 self.overlap_artifact_territorial.len(),
211 ),
212 ("threshold_base", self.threshold_base.len()),
213 ("threshold_evr_penalty", self.threshold_evr_penalty.len()),
214 ("min_evr_improvement", self.min_evr_improvement.len()),
215 ];
216 for (name, len) in common {
217 if len == 0 {
218 return Err(PipelineError::InvalidSearchSpace(format!(
219 "axis `{name}` is empty"
220 )));
221 }
222 }
223 if matches!(kind, ProjectionKind::LaplacianEigenmap) {
224 if self.laplacian_k_neighbors.is_empty() {
225 return Err(PipelineError::InvalidSearchSpace(
226 "axis `laplacian_k_neighbors` is empty".into(),
227 ));
228 }
229 if self.laplacian_active_threshold.is_empty() {
230 return Err(PipelineError::InvalidSearchSpace(
231 "axis `laplacian_active_threshold` is empty".into(),
232 ));
233 }
234 }
235 if matches!(kind, ProjectionKind::UmapSphere) {
236 if self.umap_n_neighbors.is_empty() {
237 return Err(PipelineError::InvalidSearchSpace(
238 "axis `umap_n_neighbors` is empty".into(),
239 ));
240 }
241 if self.umap_n_epochs.is_empty() {
242 return Err(PipelineError::InvalidSearchSpace(
243 "axis `umap_n_epochs` is empty".into(),
244 ));
245 }
246 if self.umap_category_weight.is_empty() {
247 return Err(PipelineError::InvalidSearchSpace(
248 "axis `umap_category_weight` is empty".into(),
249 ));
250 }
251 if self.umap_min_dist.is_empty() {
252 return Err(PipelineError::InvalidSearchSpace(
253 "axis `umap_min_dist` is empty".into(),
254 ));
255 }
256 }
257 Ok(())
258 }
259
260 fn assert_axes_non_empty(&self, kind: ProjectionKind) {
266 if let Err(e) = self.check_axes_non_empty(kind) {
267 panic!("{e}");
268 }
269 }
270
271 fn common_cardinality(&self) -> usize {
275 self.num_domain_groups.len()
276 * self.low_evr_threshold.len()
277 * self.overlap_artifact_territorial.len()
278 * self.threshold_base.len()
279 * self.threshold_evr_penalty.len()
280 * self.min_evr_improvement.len()
281 }
282
283 fn kind_cardinality(&self, kind: ProjectionKind) -> usize {
286 let common = self.common_cardinality();
287 match kind {
288 ProjectionKind::LaplacianEigenmap => {
289 common * self.laplacian_k_neighbors.len() * self.laplacian_active_threshold.len()
290 }
291 ProjectionKind::UmapSphere => {
292 common
293 * self.umap_n_neighbors.len()
294 * self.umap_n_epochs.len()
295 * self.umap_category_weight.len()
296 * self.umap_min_dist.len()
297 }
298 ProjectionKind::Pca | ProjectionKind::KernelPca => common,
299 }
300 }
301
302 pub fn grid_cardinality(&self) -> usize {
305 self.projection_kinds
306 .iter()
307 .map(|&k| self.kind_cardinality(k))
308 .sum()
309 }
310
311 pub fn config_at_index(&self, index: usize, base: &PipelineConfig) -> Option<PipelineConfig> {
319 let mut offset = 0usize;
320 for &kind in &self.projection_kinds {
321 self.assert_axes_non_empty(kind);
322 let slice = self.kind_cardinality(kind);
323 if index < offset + slice {
324 return Some(self.config_at_kind_index(kind, index - offset, base));
325 }
326 offset += slice;
327 }
328 None
329 }
330
331 fn config_at_kind_index(
333 &self,
334 kind: ProjectionKind,
335 mut idx: usize,
336 base: &PipelineConfig,
337 ) -> PipelineConfig {
338 let take = |idx: &mut usize, len: usize| -> usize {
339 let v = *idx % len;
340 *idx /= len;
341 v
342 };
343
344 let i_ndg = take(&mut idx, self.num_domain_groups.len());
345 let i_let = take(&mut idx, self.low_evr_threshold.len());
346 let i_oat = take(&mut idx, self.overlap_artifact_territorial.len());
347 let i_tb = take(&mut idx, self.threshold_base.len());
348 let i_tep = take(&mut idx, self.threshold_evr_penalty.len());
349 let i_mei = take(&mut idx, self.min_evr_improvement.len());
350
351 let mut cfg = base.clone();
352 cfg.projection_kind = kind;
353 cfg.routing = RoutingConfig {
354 num_domain_groups: self.num_domain_groups[i_ndg],
355 low_evr_threshold: self.low_evr_threshold[i_let],
356 ..base.routing.clone()
357 };
358 cfg.bridges = BridgeConfig {
359 threshold_base: self.threshold_base[i_tb],
360 threshold_evr_penalty: self.threshold_evr_penalty[i_tep],
361 overlap_artifact_territorial: self.overlap_artifact_territorial[i_oat],
362 ..base.bridges.clone()
363 };
364 cfg.inner_sphere = InnerSphereConfig {
365 min_evr_improvement: self.min_evr_improvement[i_mei],
366 ..base.inner_sphere.clone()
367 };
368
369 if matches!(kind, ProjectionKind::LaplacianEigenmap) {
370 let i_k = take(&mut idx, self.laplacian_k_neighbors.len());
371 let i_thr = take(&mut idx, self.laplacian_active_threshold.len());
372 cfg.laplacian = LaplacianConfig {
373 k_neighbors: self.laplacian_k_neighbors[i_k],
374 active_threshold: self.laplacian_active_threshold[i_thr],
375 };
376 }
377
378 if matches!(kind, ProjectionKind::UmapSphere) {
379 let i_nn = take(&mut idx, self.umap_n_neighbors.len());
380 let i_ne = take(&mut idx, self.umap_n_epochs.len());
381 let i_cw = take(&mut idx, self.umap_category_weight.len());
382 let i_md = take(&mut idx, self.umap_min_dist.len());
383 cfg.umap = UmapConfig {
384 n_neighbors: self.umap_n_neighbors[i_nn],
385 n_epochs: self.umap_n_epochs[i_ne],
386 category_weight: self.umap_category_weight[i_cw],
387 min_dist: self.umap_min_dist[i_md],
388 ..base.umap.clone()
389 };
390 }
391
392 cfg
393 }
394
395 pub(crate) fn sample(&self, rng: &mut SplitMix64, base: &PipelineConfig) -> PipelineConfig {
401 debug_assert!(
407 !self.projection_kinds.is_empty(),
408 "SearchSpace::sample called without prior validate()"
409 );
410 let mut cfg = base.clone();
411 cfg.projection_kind = pick_uniform(rng, &self.projection_kinds);
412 debug_assert!(
413 self.check_axes_non_empty(cfg.projection_kind).is_ok(),
414 "SearchSpace::sample called without prior validate()"
415 );
416 cfg.routing = RoutingConfig {
417 num_domain_groups: pick_uniform(rng, &self.num_domain_groups),
418 low_evr_threshold: pick_uniform(rng, &self.low_evr_threshold),
419 ..base.routing.clone()
420 };
421 cfg.bridges = BridgeConfig {
422 threshold_base: pick_uniform(rng, &self.threshold_base),
423 threshold_evr_penalty: pick_uniform(rng, &self.threshold_evr_penalty),
424 overlap_artifact_territorial: pick_uniform(rng, &self.overlap_artifact_territorial),
425 ..base.bridges.clone()
426 };
427 cfg.inner_sphere = InnerSphereConfig {
428 min_evr_improvement: pick_uniform(rng, &self.min_evr_improvement),
429 ..base.inner_sphere.clone()
430 };
431
432 if matches!(cfg.projection_kind, ProjectionKind::LaplacianEigenmap) {
433 cfg.laplacian = LaplacianConfig {
434 k_neighbors: pick_uniform(rng, &self.laplacian_k_neighbors),
435 active_threshold: pick_uniform(rng, &self.laplacian_active_threshold),
436 };
437 }
438
439 if matches!(cfg.projection_kind, ProjectionKind::UmapSphere) {
440 cfg.umap = UmapConfig {
441 n_neighbors: pick_uniform(rng, &self.umap_n_neighbors),
442 n_epochs: pick_uniform(rng, &self.umap_n_epochs),
443 category_weight: pick_uniform(rng, &self.umap_category_weight),
444 min_dist: pick_uniform(rng, &self.umap_min_dist),
445 ..base.umap.clone()
446 };
447 }
448
449 cfg
450 }
451}
452
453#[derive(Clone, PartialEq, Eq, Hash)]
463enum ProjectionFitKey {
464 Pca,
465 KernelPca,
466 Laplacian {
467 k: usize,
468 threshold_bits: u64,
469 },
470 UmapSphere {
471 n_neighbors: usize,
472 n_epochs: usize,
473 category_weight_bits: u64,
474 min_dist_bits: u64,
475 },
476}
477
478impl ProjectionFitKey {
479 fn from_config(cfg: &PipelineConfig) -> Self {
480 match cfg.projection_kind {
481 ProjectionKind::Pca => Self::Pca,
482 ProjectionKind::KernelPca => Self::KernelPca,
483 ProjectionKind::LaplacianEigenmap => Self::Laplacian {
484 k: cfg.laplacian.k_neighbors,
485 threshold_bits: cfg.laplacian.active_threshold.to_bits(),
486 },
487 ProjectionKind::UmapSphere => Self::UmapSphere {
488 n_neighbors: cfg.umap.n_neighbors,
489 n_epochs: cfg.umap.n_epochs,
490 category_weight_bits: cfg.umap.category_weight.to_bits(),
491 min_dist_bits: cfg.umap.min_dist.to_bits(),
496 },
497 }
498 }
499}
500
/// How [`auto_tune`] walks a [`SearchSpace`].
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// Exhaustive enumeration of every grid point; has no wall-clock limit.
    Grid,
    /// Random sampling of the space.
    Random {
        /// Total number of trials, including the initial run of the base config.
        budget: usize,
        /// Seed for the sampler.
        seed: u64,
        /// Optional wall-clock limit in seconds, checked between trials.
        max_wall_secs: Option<u64>,
    },
    /// Tree-structured Parzen estimator (TPE) style search.
    Bayesian {
        /// Total number of trials, including warm-up.
        budget: usize,
        /// Trials (base config plus random samples) run before TPE proposals begin.
        warmup: usize,
        /// Fraction of completed trials treated as "good" when weighting proposals.
        gamma: f64,
        /// Seed for the sampler.
        seed: u64,
        /// Optional wall-clock limit in seconds, checked between trials.
        max_wall_secs: Option<u64>,
    },
}

impl SearchStrategy {
    /// Wall-clock limit in seconds, if the strategy carries one (`Grid` never does).
    fn max_wall_secs(&self) -> Option<u64> {
        if let Self::Random { max_wall_secs, .. } | Self::Bayesian { max_wall_secs, .. } = self {
            *max_wall_secs
        } else {
            None
        }
    }
}
555
556#[derive(Debug, Clone)]
558pub struct TrialRecord {
559 pub config: PipelineConfig,
560 pub score: f64,
561 pub build_ms: u128,
564 pub components: Vec<(String, f64, f64)>,
571}
572
573#[derive(Debug, Clone)]
575pub struct TuneReport {
576 pub metric_name: String,
577 pub best_score: f64,
578 pub best_config: PipelineConfig,
579 pub trials: Vec<TrialRecord>,
580 pub failures: Vec<(PipelineConfig, String)>,
584 pub umap_graph_builds: usize,
590}
591
592impl TuneReport {
593 pub fn ranked_trials(&self) -> Vec<&TrialRecord> {
595 let mut refs: Vec<&TrialRecord> = self.trials.iter().collect();
596 refs.sort_by(|a, b| {
597 b.score
598 .partial_cmp(&a.score)
599 .unwrap_or(std::cmp::Ordering::Equal)
600 });
601 refs
602 }
603
604 pub fn mean_score(&self) -> f64 {
608 if self.trials.is_empty() {
609 return 0.0;
610 }
611 self.trials.iter().map(|t| t.score).sum::<f64>() / self.trials.len() as f64
612 }
613}
614
615pub fn auto_tune<M: QualityMetric + ?Sized>(
634 input: PipelineInput,
635 space: &SearchSpace,
636 metric: &M,
637 strategy: SearchStrategy,
638 base_config: &PipelineConfig,
639) -> Result<(SphereQLPipeline, TuneReport), PipelineError> {
640 space.validate(&strategy)?;
645
646 let categories = input.categories;
649 let embeddings: Vec<Embedding> = input.embeddings.into_iter().map(Embedding::new).collect();
650
651 let mut prefit: HashMap<ProjectionFitKey, ConfiguredProjection> = HashMap::new();
652 let mut umap_graph_cache: HashMap<usize, crate::umap::UmapGraph> = HashMap::new();
659 let mut umap_graph_builds: usize = 0;
660 let mut trials: Vec<TrialRecord> = Vec::new();
661 let mut failures: Vec<(PipelineConfig, String)> = Vec::new();
662 let mut best: Option<(f64, SphereQLPipeline)> = None;
668
669 let run_trial = |cfg: PipelineConfig,
673 prefit: &mut HashMap<ProjectionFitKey, ConfiguredProjection>,
674 umap_graph_cache: &mut HashMap<usize, crate::umap::UmapGraph>,
675 umap_graph_builds: &mut usize,
676 trials: &mut Vec<TrialRecord>,
677 failures: &mut Vec<(PipelineConfig, String)>,
678 best: &mut Option<(f64, SphereQLPipeline)>| {
679 let key = ProjectionFitKey::from_config(&cfg);
680 let projection = if cfg.projection_kind == ProjectionKind::UmapSphere {
681 match prefit.get(&key) {
688 Some(p) => p.clone(),
689 None => {
690 let k = cfg.umap.n_neighbors;
691 if let std::collections::hash_map::Entry::Vacant(entry) =
692 umap_graph_cache.entry(k)
693 {
694 match crate::umap::UmapGraph::build(&embeddings, k) {
695 Ok(g) => {
696 entry.insert(g);
697 *umap_graph_builds += 1;
698 }
699 Err(err) => {
700 failures.push((cfg, err.to_string()));
701 return;
702 }
703 }
704 }
705 let graph = &umap_graph_cache[&k];
706 match fit_umap_from_graph(graph, &categories, &cfg) {
707 Ok(p) => {
708 prefit.insert(key, p.clone());
709 p
710 }
711 Err(e) => {
712 failures.push((cfg, e.to_string()));
713 return;
714 }
715 }
716 }
717 }
718 } else {
719 match prefit.get(&key) {
720 Some(p) => p.clone(),
721 None => match fit_projection_for_config(&embeddings, &categories, &cfg) {
722 Ok(p) => {
723 prefit.insert(key, p.clone());
724 p
725 }
726 Err(e) => {
727 failures.push((cfg, e.to_string()));
728 return;
729 }
730 },
731 }
732 };
733
734 let start = Instant::now();
735 match SphereQLPipeline::with_projection_parts(
741 categories.clone(),
742 &embeddings,
743 projection,
744 cfg.clone(),
745 ) {
746 Ok(pipeline) => {
747 let (score, components) = metric.score_with_components(&pipeline);
748 let build_ms = start.elapsed().as_millis();
749 trials.push(TrialRecord {
750 config: cfg,
751 score,
752 build_ms,
753 components,
754 });
755 let replace = match best {
756 Some((best_score, _)) => !matches!(
757 score.partial_cmp(best_score),
758 Some(std::cmp::Ordering::Less)
759 ),
760 None => true,
761 };
762 if replace {
763 *best = Some((score, pipeline));
764 }
765 }
766 Err(e) => {
767 failures.push((cfg, e.to_string()));
768 }
769 }
770 };
771
772 let wall_start = Instant::now();
773 let max_wall = strategy.max_wall_secs();
774 let wall_exceeded = || match max_wall {
775 Some(max_secs) => wall_start.elapsed().as_secs() >= max_secs,
776 None => false,
777 };
778
779 match &strategy {
780 SearchStrategy::Grid => {
781 for i in 0..space.grid_cardinality() {
785 if let Some(cfg) = space.config_at_index(i, base_config) {
786 run_trial(
787 cfg,
788 &mut prefit,
789 &mut umap_graph_cache,
790 &mut umap_graph_builds,
791 &mut trials,
792 &mut failures,
793 &mut best,
794 );
795 }
796 }
797 }
798 SearchStrategy::Random { budget, seed, .. } => {
799 let mut rng = SplitMix64::new(*seed);
800 run_trial(
803 base_config.clone(),
804 &mut prefit,
805 &mut umap_graph_cache,
806 &mut umap_graph_builds,
807 &mut trials,
808 &mut failures,
809 &mut best,
810 );
811 if !wall_exceeded() {
812 for _ in 1..*budget {
813 let cfg = space.sample(&mut rng, base_config);
814 run_trial(
815 cfg,
816 &mut prefit,
817 &mut umap_graph_cache,
818 &mut umap_graph_builds,
819 &mut trials,
820 &mut failures,
821 &mut best,
822 );
823 if wall_exceeded() {
824 break;
825 }
826 }
827 }
828 }
829 SearchStrategy::Bayesian {
830 budget,
831 warmup,
832 gamma,
833 seed,
834 ..
835 } => {
836 let budget = *budget;
838 let mut rng = SplitMix64::new(*seed);
839 let warmup = (*warmup).clamp(2, budget);
840 let gamma = gamma.clamp(0.05, 0.95);
841
842 run_trial(
844 base_config.clone(),
845 &mut prefit,
846 &mut umap_graph_cache,
847 &mut umap_graph_builds,
848 &mut trials,
849 &mut failures,
850 &mut best,
851 );
852 if !wall_exceeded() {
854 for _ in 1..warmup {
855 let cfg = space.sample(&mut rng, base_config);
856 run_trial(
857 cfg,
858 &mut prefit,
859 &mut umap_graph_cache,
860 &mut umap_graph_builds,
861 &mut trials,
862 &mut failures,
863 &mut best,
864 );
865 if wall_exceeded() {
866 break;
867 }
868 }
869 }
870 if !wall_exceeded() {
872 for _ in warmup..budget {
873 let cfg = tpe_propose(space, base_config, &trials, gamma, &mut rng);
874 run_trial(
875 cfg,
876 &mut prefit,
877 &mut umap_graph_cache,
878 &mut umap_graph_builds,
879 &mut trials,
880 &mut failures,
881 &mut best,
882 );
883 if wall_exceeded() {
884 break;
885 }
886 }
887 }
888 }
889 }
890
891 if trials.is_empty() {
892 return Err(PipelineError::AllTrialsFailed { failures });
896 }
897
898 let (best_score, best_pipeline) = best.expect("non-empty trials imply a kept best pipeline");
903 let best_config = best_pipeline.config().clone();
904
905 let report = TuneReport {
906 metric_name: metric.name().to_string(),
907 best_score,
908 best_config,
909 trials,
910 failures,
911 umap_graph_builds,
912 };
913
914 Ok((best_pipeline, report))
915}
916
917fn tpe_propose(
932 space: &SearchSpace,
933 base: &PipelineConfig,
934 trials: &[TrialRecord],
935 gamma: f64,
936 rng: &mut SplitMix64,
937) -> PipelineConfig {
938 let mut sorted: Vec<&TrialRecord> = trials.iter().collect();
940 sorted.sort_by(|a, b| {
941 b.score
942 .partial_cmp(&a.score)
943 .unwrap_or(std::cmp::Ordering::Equal)
944 });
945 let n_good = ((sorted.len() as f64) * gamma).ceil() as usize;
946 let n_good = n_good.max(1).min(sorted.len().saturating_sub(1).max(1));
947 let good: Vec<&TrialRecord> = sorted.iter().take(n_good).copied().collect();
948 let bad: Vec<&TrialRecord> = sorted.iter().skip(n_good).copied().collect();
949
950 if good.is_empty() || bad.is_empty() {
952 return space.sample(rng, base);
953 }
954
955 let pick_idx = |rng: &mut SplitMix64, good_counts: &[f64], bad_counts: &[f64]| -> usize {
956 let n_g = good_counts.iter().sum::<f64>() + good_counts.len() as f64;
957 let n_b = bad_counts.iter().sum::<f64>() + bad_counts.len() as f64;
958 let weights: Vec<f64> = good_counts
959 .iter()
960 .zip(bad_counts.iter())
961 .map(|(&g, &b)| ((g + 1.0) / n_g) / ((b + 1.0) / n_b))
962 .collect();
963 sample_categorical(rng, &weights)
964 };
965
966 let pk_g = hist_kind(&good, &space.projection_kinds);
968 let pk_b = hist_kind(&bad, &space.projection_kinds);
969 let kind = space.projection_kinds[pick_idx(rng, &pk_g, &pk_b)];
970
971 let ndg_g = hist_usize(&good, &space.num_domain_groups, |c| {
977 c.routing.num_domain_groups
978 });
979 let ndg_b = hist_usize(&bad, &space.num_domain_groups, |c| {
980 c.routing.num_domain_groups
981 });
982 let let_g = hist_f64(&good, &space.low_evr_threshold, |c| {
983 c.routing.low_evr_threshold
984 });
985 let let_b = hist_f64(&bad, &space.low_evr_threshold, |c| {
986 c.routing.low_evr_threshold
987 });
988 let oat_g = hist_f64(&good, &space.overlap_artifact_territorial, |c| {
989 c.bridges.overlap_artifact_territorial
990 });
991 let oat_b = hist_f64(&bad, &space.overlap_artifact_territorial, |c| {
992 c.bridges.overlap_artifact_territorial
993 });
994 let tb_g = hist_f64(&good, &space.threshold_base, |c| c.bridges.threshold_base);
995 let tb_b = hist_f64(&bad, &space.threshold_base, |c| c.bridges.threshold_base);
996 let tep_g = hist_f64(&good, &space.threshold_evr_penalty, |c| {
997 c.bridges.threshold_evr_penalty
998 });
999 let tep_b = hist_f64(&bad, &space.threshold_evr_penalty, |c| {
1000 c.bridges.threshold_evr_penalty
1001 });
1002 let mei_g = hist_f64(&good, &space.min_evr_improvement, |c| {
1003 c.inner_sphere.min_evr_improvement
1004 });
1005 let mei_b = hist_f64(&bad, &space.min_evr_improvement, |c| {
1006 c.inner_sphere.min_evr_improvement
1007 });
1008
1009 let mut cfg = base.clone();
1010 cfg.projection_kind = kind;
1011 cfg.routing = RoutingConfig {
1012 num_domain_groups: space.num_domain_groups[pick_idx(rng, &ndg_g, &ndg_b)],
1013 low_evr_threshold: space.low_evr_threshold[pick_idx(rng, &let_g, &let_b)],
1014 ..base.routing.clone()
1015 };
1016 cfg.bridges = BridgeConfig {
1017 threshold_base: space.threshold_base[pick_idx(rng, &tb_g, &tb_b)],
1018 threshold_evr_penalty: space.threshold_evr_penalty[pick_idx(rng, &tep_g, &tep_b)],
1019 overlap_artifact_territorial: space.overlap_artifact_territorial
1020 [pick_idx(rng, &oat_g, &oat_b)],
1021 ..base.bridges.clone()
1022 };
1023 cfg.inner_sphere = InnerSphereConfig {
1024 min_evr_improvement: space.min_evr_improvement[pick_idx(rng, &mei_g, &mei_b)],
1025 ..base.inner_sphere.clone()
1026 };
1027
1028 if matches!(kind, ProjectionKind::LaplacianEigenmap) {
1030 let good_l: Vec<&TrialRecord> = good
1031 .iter()
1032 .copied()
1033 .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
1034 .collect();
1035 let bad_l: Vec<&TrialRecord> = bad
1036 .iter()
1037 .copied()
1038 .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
1039 .collect();
1040 if good_l.is_empty() || bad_l.is_empty() {
1041 cfg.laplacian = LaplacianConfig {
1043 k_neighbors: pick_uniform(rng, &space.laplacian_k_neighbors),
1044 active_threshold: pick_uniform(rng, &space.laplacian_active_threshold),
1045 };
1046 } else {
1047 let k_g = hist_usize(&good_l, &space.laplacian_k_neighbors, |c| {
1048 c.laplacian.k_neighbors
1049 });
1050 let k_b = hist_usize(&bad_l, &space.laplacian_k_neighbors, |c| {
1051 c.laplacian.k_neighbors
1052 });
1053 let at_g = hist_f64(&good_l, &space.laplacian_active_threshold, |c| {
1054 c.laplacian.active_threshold
1055 });
1056 let at_b = hist_f64(&bad_l, &space.laplacian_active_threshold, |c| {
1057 c.laplacian.active_threshold
1058 });
1059 cfg.laplacian = LaplacianConfig {
1060 k_neighbors: space.laplacian_k_neighbors[pick_idx(rng, &k_g, &k_b)],
1061 active_threshold: space.laplacian_active_threshold[pick_idx(rng, &at_g, &at_b)],
1062 };
1063 }
1064 }
1065
1066 if matches!(kind, ProjectionKind::UmapSphere) {
1067 let good_u: Vec<&TrialRecord> = good
1068 .iter()
1069 .copied()
1070 .filter(|t| t.config.projection_kind == ProjectionKind::UmapSphere)
1071 .collect();
1072 let bad_u: Vec<&TrialRecord> = bad
1073 .iter()
1074 .copied()
1075 .filter(|t| t.config.projection_kind == ProjectionKind::UmapSphere)
1076 .collect();
1077 if good_u.is_empty() || bad_u.is_empty() {
1078 cfg.umap = UmapConfig {
1079 n_neighbors: pick_uniform(rng, &space.umap_n_neighbors),
1080 n_epochs: pick_uniform(rng, &space.umap_n_epochs),
1081 category_weight: pick_uniform(rng, &space.umap_category_weight),
1082 min_dist: pick_uniform(rng, &space.umap_min_dist),
1083 ..base.umap.clone()
1084 };
1085 } else {
1086 let nn_g = hist_usize(&good_u, &space.umap_n_neighbors, |c| c.umap.n_neighbors);
1087 let nn_b = hist_usize(&bad_u, &space.umap_n_neighbors, |c| c.umap.n_neighbors);
1088 let ne_g = hist_usize(&good_u, &space.umap_n_epochs, |c| c.umap.n_epochs);
1089 let ne_b = hist_usize(&bad_u, &space.umap_n_epochs, |c| c.umap.n_epochs);
1090 let cw_g = hist_f64(&good_u, &space.umap_category_weight, |c| {
1091 c.umap.category_weight
1092 });
1093 let cw_b = hist_f64(&bad_u, &space.umap_category_weight, |c| {
1094 c.umap.category_weight
1095 });
1096 let md_g = hist_f64(&good_u, &space.umap_min_dist, |c| c.umap.min_dist);
1097 let md_b = hist_f64(&bad_u, &space.umap_min_dist, |c| c.umap.min_dist);
1098 cfg.umap = UmapConfig {
1099 n_neighbors: space.umap_n_neighbors[pick_idx(rng, &nn_g, &nn_b)],
1100 n_epochs: space.umap_n_epochs[pick_idx(rng, &ne_g, &ne_b)],
1101 category_weight: space.umap_category_weight[pick_idx(rng, &cw_g, &cw_b)],
1102 min_dist: space.umap_min_dist[pick_idx(rng, &md_g, &md_b)],
1103 ..base.umap.clone()
1104 };
1105 }
1106 }
1107
1108 cfg
1109}
1110
1111fn hist_kind(trials: &[&TrialRecord], values: &[ProjectionKind]) -> Vec<f64> {
1112 let mut counts = vec![0.0f64; values.len()];
1113 for t in trials {
1114 if let Some(i) = values.iter().position(|&v| v == t.config.projection_kind) {
1115 counts[i] += 1.0;
1116 }
1117 }
1118 counts
1119}
1120
1121fn hist_usize(
1122 trials: &[&TrialRecord],
1123 values: &[usize],
1124 extract: impl Fn(&PipelineConfig) -> usize,
1125) -> Vec<f64> {
1126 let mut counts = vec![0.0f64; values.len()];
1127 for t in trials {
1128 let v = extract(&t.config);
1129 if let Some(i) = values.iter().position(|&x| x == v) {
1130 counts[i] += 1.0;
1131 }
1132 }
1133 counts
1134}
1135
1136fn hist_f64(
1141 trials: &[&TrialRecord],
1142 values: &[f64],
1143 extract: impl Fn(&PipelineConfig) -> f64,
1144) -> Vec<f64> {
1145 let mut counts = vec![0.0f64; values.len()];
1146 for t in trials {
1147 let v = extract(&t.config);
1148 if let Some((i, _)) = values.iter().enumerate().min_by(|a, b| {
1149 (a.1 - v)
1150 .abs()
1151 .partial_cmp(&(b.1 - v).abs())
1152 .unwrap_or(std::cmp::Ordering::Equal)
1153 }) {
1154 counts[i] += 1.0;
1155 }
1156 }
1157 counts
1158}
1159
1160fn pick_uniform<T: Copy>(rng: &mut SplitMix64, vals: &[T]) -> T {
1165 vals[((rng.next_f64() * vals.len() as f64) as usize).min(vals.len() - 1)]
1169}
1170
1171fn sample_categorical(rng: &mut SplitMix64, weights: &[f64]) -> usize {
1172 let total: f64 = weights.iter().sum();
1173 if total <= 0.0 || !total.is_finite() {
1174 let n = weights.len().max(1);
1175 return ((rng.next_f64() * n as f64) as usize).min(n - 1);
1176 }
1177 let r = rng.next_f64() * total;
1178 let mut acc = 0.0;
1179 for (i, &w) in weights.iter().enumerate() {
1180 acc += w;
1181 if r <= acc {
1182 return i;
1183 }
1184 }
1185 weights.len() - 1
1186}
1187
1188#[cfg(test)]
1191mod tests {
1192 use super::*;
1193 use crate::quality_metric::{BridgeCoherence, CompositeMetric, TerritorialHealth};
1194
1195 fn make_input(n: usize, dim: usize) -> PipelineInput {
1196 let mut embeddings = Vec::new();
1197 let mut categories = Vec::new();
1198 for i in 0..n {
1199 let mut v = vec![0.0; dim];
1200 if i < n / 3 {
1201 v[0] = 1.0 + (i as f64 * 0.01);
1202 v[1] = 0.1;
1203 categories.push("one".into());
1204 } else if i < 2 * n / 3 {
1205 v[2] = 1.0 + (i as f64 * 0.01);
1206 v[3] = 0.1;
1207 categories.push("two".into());
1208 } else {
1209 v[4] = 1.0 + (i as f64 * 0.01);
1210 v[5] = 0.1;
1211 categories.push("three".into());
1212 }
1213 v[6] = 0.02 * i as f64;
1214 embeddings.push(v);
1215 }
1216 PipelineInput {
1217 categories,
1218 embeddings,
1219 }
1220 }
1221
1222 fn full_search_space() -> SearchSpace {
1223 SearchSpace {
1224 projection_kinds: vec![ProjectionKind::Pca],
1225 laplacian_k_neighbors: vec![15],
1226 laplacian_active_threshold: vec![0.05],
1227 umap_n_neighbors: vec![15],
1228 umap_n_epochs: vec![200],
1229 umap_category_weight: vec![1.5],
1230 umap_min_dist: vec![0.1],
1231 num_domain_groups: vec![3],
1232 low_evr_threshold: vec![0.3],
1233 overlap_artifact_territorial: vec![0.3],
1234 threshold_base: vec![0.5],
1235 threshold_evr_penalty: vec![0.4],
1236 min_evr_improvement: vec![0.10],
1237 }
1238 }
1239
1240 #[test]
1241 fn validate_rejects_empty_projection_kinds_for_every_strategy() {
1242 let mut s = full_search_space();
1243 s.projection_kinds.clear();
1244 for strategy in [
1245 SearchStrategy::Grid,
1246 SearchStrategy::Random {
1247 budget: 4,
1248 seed: 1,
1249 max_wall_secs: None,
1250 },
1251 SearchStrategy::Bayesian {
1252 budget: 4,
1253 warmup: 2,
1254 gamma: 0.25,
1255 seed: 1,
1256 max_wall_secs: None,
1257 },
1258 ] {
1259 match s.validate(&strategy) {
1260 Err(PipelineError::InvalidSearchSpace(msg)) => {
1261 assert!(msg.contains("projection_kinds"), "msg = {msg:?}");
1262 }
1263 other => panic!("expected InvalidSearchSpace, got {other:?}"),
1264 }
1265 }
1266 }
1267
1268 #[test]
1269 fn validate_rejects_empty_axis() {
1270 let mut s = full_search_space();
1271 s.threshold_base.clear();
1272 match s.validate(&SearchStrategy::Grid) {
1273 Err(PipelineError::InvalidSearchSpace(msg)) => {
1274 assert!(msg.contains("threshold_base"), "msg = {msg:?}");
1275 }
1276 other => panic!("expected InvalidSearchSpace, got {other:?}"),
1277 }
1278 }
1279
1280 #[test]
1281 fn validate_rejects_empty_laplacian_axis_only_when_kind_present() {
1282 let mut s = full_search_space();
1283 s.laplacian_k_neighbors.clear();
1284 assert!(s.validate(&SearchStrategy::Grid).is_ok());
1287 s.projection_kinds.push(ProjectionKind::LaplacianEigenmap);
1288 match s.validate(&SearchStrategy::Grid) {
1289 Err(PipelineError::InvalidSearchSpace(msg)) => {
1290 assert!(msg.contains("laplacian_k_neighbors"), "msg = {msg:?}");
1291 }
1292 other => panic!("expected InvalidSearchSpace, got {other:?}"),
1293 }
1294 }
1295
1296 #[test]
1297 fn validate_rejects_bad_bayesian_params() {
1298 let s = full_search_space();
1299 let cases: &[(SearchStrategy, &str)] = &[
1300 (
1301 SearchStrategy::Bayesian {
1302 budget: 1,
1303 warmup: 2,
1304 gamma: 0.25,
1305 seed: 1,
1306 max_wall_secs: None,
1307 },
1308 "budget",
1309 ),
1310 (
1311 SearchStrategy::Bayesian {
1312 budget: 5,
1313 warmup: 1,
1314 gamma: 0.25,
1315 seed: 1,
1316 max_wall_secs: None,
1317 },
1318 "warmup",
1319 ),
1320 (
1321 SearchStrategy::Bayesian {
1322 budget: 5,
1323 warmup: 2,
1324 gamma: 0.0,
1325 seed: 1,
1326 max_wall_secs: None,
1327 },
1328 "gamma",
1329 ),
1330 (
1331 SearchStrategy::Bayesian {
1332 budget: 5,
1333 warmup: 2,
1334 gamma: f64::NAN,
1335 seed: 1,
1336 max_wall_secs: None,
1337 },
1338 "gamma",
1339 ),
1340 ];
1341 for (strategy, needle) in cases {
1342 match s.validate(strategy) {
1343 Err(PipelineError::InvalidSearchSpace(msg)) => {
1344 assert!(msg.contains(needle), "msg={msg:?} needle={needle:?}");
1345 }
1346 other => panic!("expected InvalidSearchSpace for {needle:?}, got {other:?}"),
1347 }
1348 }
1349 }
1350
1351 #[test]
1352 fn auto_tune_propagates_invalid_search_space_for_grid() {
1353 let s = SearchSpace {
1354 projection_kinds: vec![],
1355 ..full_search_space()
1356 };
1357 let metric = BridgeCoherence;
1358 let base = PipelineConfig::default();
1359 match auto_tune(make_input(30, 10), &s, &metric, SearchStrategy::Grid, &base) {
1360 Err(PipelineError::InvalidSearchSpace(_)) => {}
1361 Err(other) => panic!("expected InvalidSearchSpace, got {other:?}"),
1362 Ok(_) => panic!("expected error, got Ok"),
1363 }
1364 }
1365
#[test]
fn search_space_grid_cardinality_sums_per_kind() {
    // The default space searches PCA and Laplacian eigenmaps. Each kind covers
    // every combination of the shared routing/bridge axes. Laplacian also
    // multiplies by its own k-neighbour and active-threshold axes.
    let space = SearchSpace::default();
    let shared_axes = [
        space.num_domain_groups.len(),
        space.low_evr_threshold.len(),
        space.overlap_artifact_territorial.len(),
        space.threshold_base.len(),
        space.threshold_evr_penalty.len(),
        space.min_evr_improvement.len(),
    ];
    let shared: usize = shared_axes.iter().product();
    let laplacian_extra =
        space.laplacian_k_neighbors.len() * space.laplacian_active_threshold.len();
    let pca_configs = shared;
    let laplacian_configs = shared * laplacian_extra;
    assert_eq!(space.grid_cardinality(), pca_configs + laplacian_configs);
}
1381
#[test]
fn default_search_space_includes_pca_and_laplacian() {
    // PCA and Laplacian eigenmaps are searched by default; Kernel PCA is not.
    let space = SearchSpace::default();
    let kinds = &space.projection_kinds;
    for expected in [ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap] {
        assert!(kinds.contains(&expected));
    }
    assert!(!kinds.contains(&ProjectionKind::KernelPca));
}
1393
#[test]
fn grid_index_enumerates_full_space() {
    // Only `num_domain_groups` and `low_evr_threshold` vary (two values each).
    // Every grid index must therefore map to a distinct (groups, threshold) pair,
    // and the index one past the end must yield `None`.
    let s = SearchSpace {
        projection_kinds: vec![ProjectionKind::Pca],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![15],
        umap_n_epochs: vec![200],
        umap_category_weight: vec![1.5],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3, 5],
        low_evr_threshold: vec![0.3, 0.4],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let base = PipelineConfig::default();
    let n = s.grid_cardinality();
    let mut seen = std::collections::HashSet::new();
    for i in 0..n {
        let cfg = s.config_at_index(i, &base).unwrap();
        // Key the threshold by its exact bit pattern. Truncating `x * 1000.0` to an
        // integer would merge distinct thresholds (0.3001 and 0.3004 both become
        // 300) and make this test fail spuriously.
        let key = (
            cfg.routing.num_domain_groups,
            cfg.routing.low_evr_threshold.to_bits(),
        );
        seen.insert(key);
    }
    assert_eq!(seen.len(), n);
    assert!(s.config_at_index(n, &base).is_none());
}
1425
#[test]
fn grid_index_enumerates_across_projection_kinds() {
    // Every axis except `projection_kinds` is pinned to one value, so walking
    // the whole grid must surface both PCA and Laplacian configs.
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![15],
        umap_n_epochs: vec![200],
        umap_category_weight: vec![1.5],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let base = PipelineConfig::default();
    let mut kinds = std::collections::HashSet::new();
    for idx in 0..space.grid_cardinality() {
        let cfg = space.config_at_index(idx, &base).unwrap();
        kinds.insert(cfg.projection_kind);
    }
    assert_eq!(kinds.len(), 2);
    assert!(kinds.contains(&ProjectionKind::Pca));
    assert!(kinds.contains(&ProjectionKind::LaplacianEigenmap));
}
1451
#[test]
fn grid_search_runs_and_picks_best() {
    // Only `num_domain_groups` varies (3 or 5), so the grid has exactly two points.
    let input = make_input(24, 8);
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::Pca],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![15],
        umap_n_epochs: vec![200],
        umap_category_weight: vec![1.5],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3, 5],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let metric = TerritorialHealth;
    let (tuned, summary) = auto_tune(
        input,
        &space,
        &metric,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    )
    .unwrap();

    assert_eq!(summary.trials.len(), 2);
    // The winner may not score below the trial average (1e-9 slack for rounding).
    let mean = summary.mean_score();
    assert!(summary.best_score >= mean - 1e-9);
    assert!(tuned.num_categories() > 0);
    assert_eq!(summary.metric_name, "territorial_health");
    assert!(summary.failures.is_empty());
}
1486
#[test]
fn trial_records_carry_component_breakdown_for_composites() {
    let input = make_input(24, 8);
    let composite = CompositeMetric::default_composite();
    let (_pipeline, report) = auto_tune(
        input,
        &full_search_space(),
        &composite,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    )
    .unwrap();
    assert!(!report.trials.is_empty());
    for trial in report.trials.iter() {
        assert_eq!(
            trial.components.len(),
            4,
            "composite trials must record the 4-component breakdown"
        );
        // Summing weight × component score over the breakdown must reproduce the
        // recorded trial score.
        let weighted_sum: f64 = trial.components.iter().map(|(_, w, s)| w * s).sum();
        let gap = (trial.score - weighted_sum).abs();
        assert!(gap < 1e-12, "breakdown must recompose to the recorded score");
    }
}
1513
#[test]
fn trial_records_have_empty_components_for_leaf_metrics() {
    // A leaf metric such as TerritorialHealth is expected to record no
    // per-component breakdown on any trial.
    let input = make_input(24, 8);
    let leaf = TerritorialHealth;
    let (_pipeline, report) = auto_tune(
        input,
        &full_search_space(),
        &leaf,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    )
    .unwrap();
    assert!(!report.trials.is_empty());
    assert!(report.trials.iter().all(|t| t.components.is_empty()));
}
1531
#[test]
fn random_search_respects_budget() {
    // Without a wall-time cap, random search must run exactly `budget` trials.
    let strategy = SearchStrategy::Random {
        budget: 5,
        seed: 42,
        max_wall_secs: None,
    };
    let metric = BridgeCoherence;
    let space = SearchSpace::default();
    let (_pipeline, report) = auto_tune(
        make_input(24, 8),
        &space,
        &metric,
        strategy,
        &PipelineConfig::default(),
    )
    .unwrap();
    assert_eq!(report.trials.len(), 5);
}
1551
#[test]
fn random_search_respects_wall_time_cap() {
    // A zero-second cap must cut a 1000-trial budget short. The search must
    // still finish at least one trial before the cap can take effect.
    let capped = SearchStrategy::Random {
        budget: 1000,
        seed: 42,
        max_wall_secs: Some(0),
    };
    let metric = TerritorialHealth;
    let space = SearchSpace::default();
    let (_pipeline, report) = auto_tune(
        make_input(24, 8),
        &space,
        &metric,
        capped,
        &PipelineConfig::default(),
    )
    .unwrap();
    let completed = report.trials.len();
    assert!(
        completed < 1000,
        "wall time cap should have stopped early, got {} trials",
        completed
    );
    assert!(
        completed > 0,
        "should complete at least one trial before checking wall time"
    );
}
1584
#[test]
fn none_wall_time_is_unlimited() {
    // With `max_wall_secs: None`, only the budget bounds the search.
    let metric = TerritorialHealth;
    let (_pipeline, report) = auto_tune(
        make_input(24, 8),
        &full_search_space(),
        &metric,
        SearchStrategy::Random {
            budget: 3,
            seed: 1,
            max_wall_secs: None,
        },
        &PipelineConfig::default(),
    )
    .unwrap();
    assert_eq!(report.trials.len(), 3);
}
1604
#[test]
fn random_search_is_seed_reproducible() {
    let space = SearchSpace::default();
    let metric = TerritorialHealth;

    // Runs an 8-trial random search with `seed` and keeps only the report.
    let report_for = |seed: u64| {
        let strategy = SearchStrategy::Random {
            budget: 8,
            seed,
            max_wall_secs: None,
        };
        let (_pipeline, report) = auto_tune(
            make_input(24, 8),
            &space,
            &metric,
            strategy,
            &PipelineConfig::default(),
        )
        .unwrap();
        report
    };

    let first = report_for(7);
    let replay = report_for(7);
    let reseeded = report_for(13);

    // Same seed: the trial sequences must match config-for-config and score-for-score.
    assert_eq!(first.trials.len(), replay.trials.len());
    for (x, y) in first.trials.iter().zip(&replay.trials) {
        assert_eq!(
            x.config.routing.num_domain_groups,
            y.config.routing.num_domain_groups
        );
        assert!((x.score - y.score).abs() < 1e-12);
    }

    // Different seed: at least one paired trial must differ in a sampled knob.
    let any_differ = first.trials.iter().zip(&reseeded.trials).any(|(x, z)| {
        let groups = x.config.routing.num_domain_groups != z.config.routing.num_domain_groups;
        let threshold =
            (x.config.bridges.threshold_base - z.config.bridges.threshold_base).abs() > 1e-12;
        groups || threshold
    });
    assert!(any_differ, "different seeds produced identical trial set");
}
1649
#[test]
fn ranked_trials_are_descending() {
    let metric = CompositeMetric::default_composite();
    let strategy = SearchStrategy::Random {
        budget: 6,
        seed: 99,
        max_wall_secs: None,
    };
    let (_p, report) = auto_tune(
        make_input(24, 8),
        &SearchSpace::default(),
        &metric,
        strategy,
        &PipelineConfig::default(),
    )
    .unwrap();
    // `ranked_trials` must list trials from best to worst score.
    let ranked = report.ranked_trials();
    let descending = ranked
        .windows(2)
        .all(|pair| pair[0].score >= pair[1].score);
    assert!(descending);
}
1671
#[test]
fn best_config_actually_in_trials() {
    let metric = TerritorialHealth;
    let strategy = SearchStrategy::Random {
        budget: 4,
        seed: 1,
        max_wall_secs: None,
    };
    let (_p, report) = auto_tune(
        make_input(24, 8),
        &SearchSpace::default(),
        &metric,
        strategy,
        &PipelineConfig::default(),
    )
    .unwrap();
    // The reported winner must match a trial that was actually evaluated:
    // same routing knobs and same score.
    let best = &report.best_config;
    let found = report.trials.iter().any(|t| {
        let same_groups =
            t.config.routing.num_domain_groups == best.routing.num_domain_groups;
        let same_evr = (t.config.routing.low_evr_threshold
            - best.routing.low_evr_threshold)
            .abs()
            < 1e-12;
        let same_score = (t.score - report.best_score).abs() < 1e-12;
        same_groups && same_evr && same_score
    });
    assert!(found, "best_config must appear in trials");
}
1698
#[test]
fn grid_search_across_projection_kinds_yields_both() {
    // PCA contributes 1 config and Laplacian contributes 2 (k_neighbors in {10, 20}),
    // so the grid runs exactly 3 trials.
    let input = make_input(24, 8);
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
        laplacian_k_neighbors: vec![10, 20],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![15],
        umap_n_epochs: vec![200],
        umap_category_weight: vec![1.5],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let metric = TerritorialHealth;
    let (_pipeline, report) = auto_tune(
        input,
        &space,
        &metric,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    )
    .unwrap();
    assert_eq!(report.trials.len(), 3);

    let mut kinds_in_trials = std::collections::HashSet::new();
    let mut lap_ks = std::collections::HashSet::new();
    for t in &report.trials {
        let kind = t.config.projection_kind;
        kinds_in_trials.insert(kind);
        if kind == ProjectionKind::LaplacianEigenmap {
            lap_ks.insert(t.config.laplacian.k_neighbors);
        }
    }
    assert!(kinds_in_trials.contains(&ProjectionKind::Pca));
    assert!(kinds_in_trials.contains(&ProjectionKind::LaplacianEigenmap));
    assert_eq!(lap_ks.len(), 2);
}
1745
#[test]
fn laplacian_knobs_produce_distinct_configs() {
    // Two k-neighbour values times two active thresholds must give four distinct
    // Laplacian configs. Thresholds are compared by exact bit pattern.
    let s = SearchSpace {
        projection_kinds: vec![ProjectionKind::LaplacianEigenmap],
        laplacian_k_neighbors: vec![10, 20],
        laplacian_active_threshold: vec![0.03, 0.08],
        umap_n_neighbors: vec![15],
        umap_n_epochs: vec![200],
        umap_category_weight: vec![1.5],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let base = PipelineConfig::default();
    let unique: std::collections::HashSet<(usize, u64)> = (0..s.grid_cardinality())
        .map(|i| s.config_at_index(i, &base).unwrap())
        .map(|cfg| {
            (
                cfg.laplacian.k_neighbors,
                cfg.laplacian.active_threshold.to_bits(),
            )
        })
        .collect();
    assert_eq!(unique.len(), 4, "expected 4 distinct (k, threshold) pairs");
}
1779
#[test]
fn bayesian_respects_budget() {
    // With 4 warmup draws and a budget of 10, Bayesian search must record
    // exactly 10 trials.
    let metric = TerritorialHealth;
    let strategy = SearchStrategy::Bayesian {
        budget: 10,
        warmup: 4,
        gamma: 0.25,
        seed: 42,
        max_wall_secs: None,
    };
    let (_p, report) = auto_tune(
        make_input(24, 8),
        &SearchSpace::default(),
        &metric,
        strategy,
        &PipelineConfig::default(),
    )
    .unwrap();
    assert_eq!(report.trials.len(), 10);
}
1800
#[test]
fn bayesian_seed_reproducible() {
    let metric = TerritorialHealth;

    // Runs an 8-trial Bayesian search with `seed` and keeps only the report.
    let tune_with = |seed: u64| {
        let strategy = SearchStrategy::Bayesian {
            budget: 8,
            warmup: 3,
            gamma: 0.25,
            seed,
            max_wall_secs: None,
        };
        let (_pipeline, report) = auto_tune(
            make_input(24, 8),
            &SearchSpace::default(),
            &metric,
            strategy,
            &PipelineConfig::default(),
        )
        .unwrap();
        report
    };

    let first = tune_with(7);
    let second = tune_with(7);
    // The same seed must replay the same projection kinds and scores.
    assert_eq!(first.trials.len(), second.trials.len());
    for (x, y) in first.trials.iter().zip(second.trials.iter()) {
        assert_eq!(x.config.projection_kind, y.config.projection_kind);
        assert!((x.score - y.score).abs() < 1e-12);
    }
}
1830
#[test]
fn bayesian_finds_something_under_default_metric() {
    let input = make_input(30, 10);
    let composite = CompositeMetric::default_composite();
    let strategy = SearchStrategy::Bayesian {
        budget: 12,
        warmup: 4,
        gamma: 0.25,
        seed: 0xC0FFEE,
        max_wall_secs: None,
    };
    let (_pipeline, report) = auto_tune(
        input,
        &SearchSpace::default(),
        &composite,
        strategy,
        &PipelineConfig::default(),
    )
    .unwrap();
    assert_eq!(report.trials.len(), 12);
    // The best composite score is expected to lie in the closed interval [0, 1].
    assert!((0.0..=1.0).contains(&report.best_score));
}
1855
#[test]
fn bayesian_warmup_clamped() {
    // A warmup (100) far larger than the budget (5) must not push the
    // trial count past the budget.
    let metric = TerritorialHealth;
    let oversized_warmup = SearchStrategy::Bayesian {
        budget: 5,
        warmup: 100,
        gamma: 0.25,
        seed: 1,
        max_wall_secs: None,
    };
    let (_p, report) = auto_tune(
        make_input(24, 8),
        &SearchSpace::default(),
        &metric,
        oversized_warmup,
        &PipelineConfig::default(),
    )
    .unwrap();
    assert_eq!(report.trials.len(), 5);
}
1877
#[test]
fn umap_search_space_cardinality() {
    // `large_corpus` searches PCA and UMAP. PCA covers the shared axes once;
    // UMAP multiplies them by its four UMAP-specific axes.
    let space = SearchSpace::large_corpus();
    let shared: usize = [
        space.num_domain_groups.len(),
        space.low_evr_threshold.len(),
        space.overlap_artifact_territorial.len(),
        space.threshold_base.len(),
        space.threshold_evr_penalty.len(),
        space.min_evr_improvement.len(),
    ]
    .iter()
    .product();
    let umap_axes: usize = [
        space.umap_n_neighbors.len(),
        space.umap_n_epochs.len(),
        space.umap_category_weight.len(),
        space.umap_min_dist.len(),
    ]
    .iter()
    .product();
    assert_eq!(space.grid_cardinality(), shared + shared * umap_axes);
}
1895
#[test]
fn umap_trials_produce_umap_configs() {
    // A UMAP-only space where only n_neighbors varies (10 or 20) must yield
    // two UmapSphere trials, one per n_neighbors value.
    let input = make_input(24, 8);
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::UmapSphere],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![10, 20],
        umap_n_epochs: vec![50],
        umap_category_weight: vec![1.0],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let metric = TerritorialHealth;
    let (_pipeline, report) = auto_tune(
        input,
        &space,
        &metric,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    )
    .unwrap();

    assert_eq!(report.trials.len(), 2);
    let mut nn_values = std::collections::HashSet::new();
    for t in &report.trials {
        assert_eq!(t.config.projection_kind, ProjectionKind::UmapSphere);
        nn_values.insert(t.config.umap.n_neighbors);
    }
    assert_eq!(nn_values.len(), 2);
}
1935
#[test]
fn umap_graph_cache_reuses_across_trials_sharing_n_neighbors() {
    // 2 epoch counts times 3 category weights gives 6 UMAP configs. All of them
    // use n_neighbors = 10, so the neighbour graph should be built only once.
    let input = make_input(24, 8);
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::UmapSphere],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![10],
        umap_n_epochs: vec![30, 60],
        umap_category_weight: vec![0.0, 1.0, 2.0],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let metric = TerritorialHealth;
    let (_pipeline, report) = auto_tune(
        input,
        &space,
        &metric,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    )
    .unwrap();

    let (trial_count, graph_builds) = (report.trials.len(), report.umap_graph_builds);
    assert_eq!(trial_count, 6, "6 UMAP configs in the grid");
    assert_eq!(
        graph_builds, 1,
        "all 6 configs share n_neighbors=10, so the cache should build the graph exactly once"
    );
}
1975
#[test]
fn umap_graph_cache_builds_one_per_unique_n_neighbors() {
    // Two n_neighbors values times two epoch counts gives 4 configs. The graph
    // should be built once per distinct n_neighbors value, i.e. twice.
    let input = make_input(24, 8);
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::UmapSphere],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![10, 20],
        umap_n_epochs: vec![30, 60],
        umap_category_weight: vec![0.0],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let metric = TerritorialHealth;
    let (_pipeline, report) = auto_tune(
        input,
        &space,
        &metric,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    )
    .unwrap();

    let (trial_count, graph_builds) = (report.trials.len(), report.umap_graph_builds);
    assert_eq!(trial_count, 4);
    assert_eq!(
        graph_builds, 2,
        "n_neighbors ∈ {{10, 20}} should produce exactly 2 graph builds"
    );
}
2013
#[test]
fn umap_fit_key_distinguishes_min_dist() {
    // A config and its clone must share a fit key. A config that differs only
    // in `umap.min_dist` must get a different key.
    let umap_cfg = PipelineConfig {
        projection_kind: ProjectionKind::UmapSphere,
        ..PipelineConfig::default()
    };
    let mut shifted = umap_cfg.clone();
    shifted.umap.min_dist = 0.5;
    let same_key =
        ProjectionFitKey::from_config(&umap_cfg) == ProjectionFitKey::from_config(&umap_cfg.clone());
    assert!(same_key);
    let distinct_key =
        ProjectionFitKey::from_config(&umap_cfg) != ProjectionFitKey::from_config(&shifted);
    assert!(distinct_key);
}
2029
#[test]
fn umap_graph_cache_zero_when_no_umap_trials() {
    // A PCA-only search must never build a UMAP neighbour graph, even though
    // the UMAP axes are populated.
    let input = make_input(24, 8);
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::Pca],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![10],
        umap_n_epochs: vec![30],
        umap_category_weight: vec![0.0],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let metric = TerritorialHealth;
    let outcome = auto_tune(
        input,
        &space,
        &metric,
        SearchStrategy::Grid,
        &PipelineConfig::default(),
    );
    let (_pipeline, report) = outcome.unwrap();

    assert_eq!(report.umap_graph_builds, 0);
}
2061
#[test]
fn validate_rejects_empty_umap_axis_only_when_kind_present() {
    // An empty UMAP axis is accepted until UmapSphere is added to the searched
    // kinds. After that, validation must reject it and name the offending field.
    let mut space = full_search_space();
    space.umap_n_neighbors.clear();
    let grid = SearchStrategy::Grid;
    assert!(space.validate(&grid).is_ok());

    space.projection_kinds.push(ProjectionKind::UmapSphere);
    let verdict = space.validate(&grid);
    match verdict {
        Err(PipelineError::InvalidSearchSpace(msg)) => {
            assert!(msg.contains("umap_n_neighbors"), "msg = {msg:?}");
        }
        other => panic!("expected InvalidSearchSpace, got {other:?}"),
    }
}
2076
#[test]
fn tpe_proposes_dominating_value_more_often_than_uniform() {
    // Only `num_domain_groups` varies (3, 5 or 7). The synthetic history makes 7
    // the clear winner, so TPE should propose it far more often than uniform
    // sampling would.
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::Pca],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        umap_n_neighbors: vec![15],
        umap_n_epochs: vec![200],
        umap_category_weight: vec![1.5],
        umap_min_dist: vec![0.1],
        num_domain_groups: vec![3, 5, 7],
        low_evr_threshold: vec![0.3],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };
    let base = PipelineConfig::default();

    // Builds one synthetic history entry: a PCA config with the given group
    // count and score.
    let record = |groups: usize, score: f64| -> TrialRecord {
        let mut config = base.clone();
        config.projection_kind = ProjectionKind::Pca;
        config.routing.num_domain_groups = groups;
        TrialRecord {
            config,
            score,
            build_ms: 0,
            components: Vec::new(),
        }
    };

    // History order: four high-scoring 7s, then six interleaved low-scoring (3, 5) pairs.
    let mut history: Vec<TrialRecord> = (0..4)
        .map(|i| record(7, 0.9 + i as f64 * 0.01))
        .collect();
    for i in 0..6 {
        history.push(record(3, 0.1 + i as f64 * 0.01));
        history.push(record(5, 0.1 + i as f64 * 0.005));
    }

    let mut rng = SplitMix64::new(42);
    let n_proposals = 300;
    let count_7 = (0..n_proposals)
        .filter(|_| {
            let proposal = tpe_propose(&space, &base, &history, 0.25, &mut rng);
            proposal.routing.num_domain_groups == 7
        })
        .count();

    // Uniform sampling over three values would pick 7 about 100 times in 300;
    // require more than 60% of proposals.
    assert!(
        count_7 > 180,
        "dominating value proposed only {count_7}/{n_proposals} times (uniform ≈ {})",
        n_proposals / 3
    );
}
2141
#[test]
fn random_seeds_base_config_as_trial_zero() {
    // Random search evaluates `base_config` itself as trial 0, and that trial
    // counts against the budget. The remaining trials are sampled from the space.
    // The sentinel value 0.123 marks the base config. The loop below assumes
    // `full_search_space()` pins `overlap_artifact_territorial` to 0.3.
    let input = make_input(24, 8);
    let mut base = PipelineConfig::default();
    base.bridges.overlap_artifact_territorial = 0.123;
    let metric = TerritorialHealth;
    let (_p, report) = auto_tune(
        input,
        &full_search_space(),
        &metric,
        SearchStrategy::Random {
            budget: 4,
            seed: 9,
            max_wall_secs: None,
        },
        &base,
    )
    .unwrap();

    assert_eq!(report.trials.len(), 4, "seed trial counts against budget");
    assert!(
        (report.trials[0].config.bridges.overlap_artifact_territorial - 0.123).abs() < 1e-12,
        "trial 0 must be base_config itself"
    );
    for t in &report.trials[1..] {
        assert!(
            (t.config.bridges.overlap_artifact_territorial - 0.3).abs() < 1e-12,
            "sampled trials must come from the space's axes"
        );
    }
}
2173
#[test]
fn bayesian_seeds_base_config_as_trial_zero() {
    // Bayesian search must also evaluate `base_config` first. The sentinel
    // value 0.123 identifies it in the trial list.
    let mut base = PipelineConfig::default();
    base.bridges.overlap_artifact_territorial = 0.123;
    let metric = TerritorialHealth;
    let strategy = SearchStrategy::Bayesian {
        budget: 5,
        warmup: 2,
        gamma: 0.25,
        seed: 9,
        max_wall_secs: None,
    };
    let (_p, report) = auto_tune(
        make_input(24, 8),
        &full_search_space(),
        &metric,
        strategy,
        &base,
    )
    .unwrap();

    assert_eq!(report.trials.len(), 5);
    let seeded = &report.trials[0];
    assert!((seeded.config.bridges.overlap_artifact_territorial - 0.123).abs() < 1e-12);
}
2200
#[test]
fn grid_does_not_seed_base_config() {
    // Grid search enumerates only the space itself. The trial count must equal
    // the grid size, and the base config's sentinel value must never appear.
    let input = make_input(24, 8);
    let space = full_search_space();
    let mut base = PipelineConfig::default();
    base.bridges.overlap_artifact_territorial = 0.123;
    let metric = TerritorialHealth;
    let (_p, report) = auto_tune(input, &space, &metric, SearchStrategy::Grid, &base).unwrap();

    assert_eq!(
        report.trials.len(),
        space.grid_cardinality(),
        "grid trial count must stay the exact enumeration"
    );
    assert!(report
        .trials
        .iter()
        .all(|t| (t.config.bridges.overlap_artifact_territorial - 0.3).abs() < 1e-12));
}
2225
#[test]
fn returned_pipeline_uses_best_config() {
    // The pipeline returned by `auto_tune` must be built from the winning config.
    let metric = TerritorialHealth;
    let strategy = SearchStrategy::Random {
        budget: 4,
        seed: 11,
        max_wall_secs: None,
    };
    let (pipeline, report) = auto_tune(
        make_input(24, 8),
        &SearchSpace::default(),
        &metric,
        strategy,
        &PipelineConfig::default(),
    )
    .unwrap();
    let best = &report.best_config;
    assert_eq!(
        pipeline.config().routing.num_domain_groups,
        best.routing.num_domain_groups
    );
    assert_eq!(pipeline.projection_kind(), best.projection_kind);
}
2251}