1use std::collections::VecDeque;
48use std::sync::RwLock;
49use std::sync::atomic::{AtomicU64, Ordering};
50use std::time::{Duration, Instant};
51
/// Quantization level used to score candidates in one pipeline stage.
///
/// Variants are declared from cheapest/least accurate (`BPS`) to most expensive/most
/// accurate (`F32`); see [`StageQuantLevel::relative_cost`] and
/// [`StageQuantLevel::expected_recall`] for the heuristic numbers attached to each.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StageQuantLevel {
    /// Coarsest level (relative cost 0.05, expected recall 0.70).
    BPS,
    /// Presumably product quantization (relative cost 0.10, expected recall 0.90).
    PQ,
    /// 8-bit integer vectors (relative cost 0.25, expected recall 0.995).
    I8,
    /// Full-precision `f32` vectors: the reference level (cost 1.0, recall 1.0).
    F32,
}

impl StageQuantLevel {
    /// `(relative_cost, expected_recall)` for this level, kept side by side so the
    /// cost/accuracy trade-off of every level is visible in one table.
    const fn profile(self) -> (f32, f32) {
        match self {
            Self::BPS => (0.05, 0.70),
            Self::PQ => (0.10, 0.90),
            Self::I8 => (0.25, 0.995),
            Self::F32 => (1.00, 1.00),
        }
    }

    /// Scoring cost relative to full precision ([`StageQuantLevel::F32`] = 1.0).
    pub const fn relative_cost(self) -> f32 {
        self.profile().0
    }

    /// Heuristic recall this level achieves on its own; `F32` is exact (1.0).
    pub const fn expected_recall(self) -> f32 {
        self.profile().1
    }
}
86
87#[derive(Debug, Clone)]
89pub struct PipelineStage {
90 pub quant_level: StageQuantLevel,
92 pub input_candidates: usize,
94 pub output_candidates: usize,
96 pub apply_filter: bool,
98}
99
100impl PipelineStage {
101 pub fn estimate_cost(&self, dimension: usize, cost_model: &CostModel) -> f32 {
103 let base_cost =
104 self.input_candidates as f32 * dimension as f32 * self.quant_level.relative_cost();
105 base_cost * cost_model.cpu_cycles_per_op
106 }
107
108 pub fn estimate_recall(&self, total_vectors: usize) -> f32 {
110 let coverage = (self.input_candidates as f32 / total_vectors as f32).min(1.0);
111 self.quant_level.expected_recall() * coverage.sqrt()
112 }
113}
114
115#[derive(Debug, Clone)]
117pub struct SearchSLA {
118 pub target_recall: f32,
120 pub latency_budget: Duration,
122 pub token_budget: Option<u64>,
124 pub mode: OptimizationMode,
126}
127
128impl Default for SearchSLA {
129 fn default() -> Self {
130 Self {
131 target_recall: 0.95,
132 latency_budget: Duration::from_millis(10),
133 token_budget: None,
134 mode: OptimizationMode::Balanced,
135 }
136 }
137}
138
/// Strategy the search planner uses to choose between latency and recall.
///
/// `Balanced` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizationMode {
    /// Favor low latency: small `ef` and cheap quantization levels.
    Speed,
    /// Favor recall: large `ef` and full-precision scoring.
    Quality,
    /// Middle ground between `Speed` and `Quality`.
    #[default]
    Balanced,
    /// Adapt between the other strategies based on recorded latency feedback.
    SLO,
}
152
/// Hardware and per-stage cost constants used by the planner's estimates.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Multiplier applied to abstract stage cost in `PipelineStage::estimate_cost`.
    pub cpu_cycles_per_op: f32,
    /// Memory bandwidth in GB/s.
    ///
    /// NOTE(review): not read by the planning code in this module.
    pub memory_bandwidth_gbps: f32,
    /// L3 cache size in bytes.
    ///
    /// NOTE(review): not read by the planning code in this module.
    pub l3_cache_bytes: usize,
    /// Per-candidate scoring time for each quantization level.
    pub stage_costs_ns: StageCosts,
}

/// Nanoseconds spent scoring one candidate at each quantization level.
///
/// The planner scales these linearly by `dimension / 128` when estimating latency.
#[derive(Debug, Clone, Default)]
pub struct StageCosts {
    /// Cost per candidate at `BPS`.
    pub bps_per_candidate_ns: f32,
    /// Cost per candidate at `PQ`.
    pub pq_per_candidate_ns: f32,
    /// Cost per candidate at `I8`.
    pub i8_per_candidate_ns: f32,
    /// Cost per candidate at `F32`.
    pub f32_per_candidate_ns: f32,
}

impl Default for CostModel {
    /// Generic defaults:
    /// - 1 cycle/op;
    /// - 50 GB/s memory bandwidth;
    /// - 32 MiB L3;
    /// - 10/50/100/500 ns per candidate for BPS/PQ/I8/F32.
    fn default() -> Self {
        let stage_costs_ns = StageCosts {
            bps_per_candidate_ns: 10.0,
            pq_per_candidate_ns: 50.0,
            i8_per_candidate_ns: 100.0,
            f32_per_candidate_ns: 500.0,
        };
        CostModel {
            cpu_cycles_per_op: 1.0,
            memory_bandwidth_gbps: 50.0,
            // 32 MiB
            l3_cache_bytes: 32 << 20,
            stage_costs_ns,
        }
    }
}
194
195#[derive(Debug, Clone)]
197pub struct DatasetStats {
198 pub total_vectors: usize,
200 pub dimension: usize,
202 pub available_levels: Vec<StageQuantLevel>,
204 pub filter_selectivity: Option<f32>,
206 pub recent_latencies: Option<(Duration, Duration, Duration)>,
208}
209
210impl Default for DatasetStats {
211 fn default() -> Self {
212 Self {
213 total_vectors: 0,
214 dimension: 0,
215 available_levels: vec![StageQuantLevel::F32],
216 filter_selectivity: None,
217 recent_latencies: None,
218 }
219 }
220}
221
/// A concrete plan for one query: the ordered pipeline stages, graph-search settings, and
/// the planner's latency/recall estimates.
#[derive(Debug, Clone)]
pub struct SearchPlan {
    /// Pipeline stages in order. Each stage's output feeds the next stage's input.
    pub stages: Vec<PipelineStage>,
    /// Search-time candidate list size (presumably the graph-search `ef`).
    pub ef_search: usize,
    /// Number of results to return.
    pub k: usize,
    /// Whether search should use batched expansion (interpreted by the executor; not
    /// visible here).
    pub use_batched_expansion: bool,
    /// Prefetch distance hint for the executor (units not visible here; confirm).
    pub prefetch_distance: usize,
    /// Estimated query latency. Some constructors set a fixed placeholder here.
    pub estimated_latency: Duration,
    /// Estimated recall. Some constructors set a fixed placeholder here.
    pub estimated_recall: f32,
    /// When the plan was created.
    pub created_at: Instant,
}
242
243impl SearchPlan {
244 pub fn simple(k: usize, ef_search: usize) -> Self {
246 Self {
247 stages: vec![PipelineStage {
248 quant_level: StageQuantLevel::F32,
249 input_candidates: ef_search,
250 output_candidates: k,
251 apply_filter: false,
252 }],
253 ef_search,
254 k,
255 use_batched_expansion: true,
256 prefetch_distance: 4,
257 estimated_latency: Duration::from_millis(1),
258 estimated_recall: 0.95,
259 created_at: Instant::now(),
260 }
261 }
262
263 pub fn multi_stage(k: usize, total_vectors: usize, target_recall: f32) -> Self {
265 let coarse_candidates = (total_vectors as f32 * 0.1).min(10000.0) as usize;
267 let refine_candidates = (coarse_candidates as f32 * 0.1).max(k as f32 * 10.0) as usize;
268 let _rerank_candidates = (refine_candidates as f32 * 0.5).max(k as f32 * 2.0) as usize;
269
270 let mut stages = Vec::new();
271
272 if total_vectors > 10_000 {
274 stages.push(PipelineStage {
275 quant_level: StageQuantLevel::BPS,
276 input_candidates: total_vectors,
277 output_candidates: coarse_candidates,
278 apply_filter: true, });
280 }
281
282 if total_vectors > 1_000 {
284 stages.push(PipelineStage {
285 quant_level: StageQuantLevel::PQ,
286 input_candidates: coarse_candidates,
287 output_candidates: refine_candidates,
288 apply_filter: false,
289 });
290 }
291
292 let rerank_level = if target_recall > 0.99 {
294 StageQuantLevel::F32
295 } else {
296 StageQuantLevel::I8
297 };
298
299 stages.push(PipelineStage {
300 quant_level: rerank_level,
301 input_candidates: refine_candidates,
302 output_candidates: k,
303 apply_filter: false,
304 });
305
306 Self {
307 stages,
308 ef_search: refine_candidates.max(64),
309 k,
310 use_batched_expansion: true,
311 prefetch_distance: 4,
312 estimated_latency: Duration::from_millis(5),
313 estimated_recall: target_recall,
314 created_at: Instant::now(),
315 }
316 }
317
318 pub fn total_cost(&self, dimension: usize, cost_model: &CostModel) -> f32 {
320 self.stages
321 .iter()
322 .map(|s| s.estimate_cost(dimension, cost_model))
323 .sum()
324 }
325
326 pub fn meets_sla(&self, sla: &SearchSLA) -> bool {
328 self.estimated_recall >= sla.target_recall && self.estimated_latency <= sla.latency_budget
329 }
330}
331
/// Chooses a [`SearchPlan`] per query from a [`SearchSLA`], [`DatasetStats`], and recent
/// latency feedback.
pub struct SearchPlanner {
    /// Cost constants used for latency estimates.
    cost_model: CostModel,
    /// Rolling window of latency/recall feedback. SLO-mode planning reads the average latency.
    recent_stats: RwLock<RecentStats>,
    /// Number of `plan` calls made so far.
    query_count: AtomicU64,
}
341
/// Rolling window of recent query latencies and recalls, oldest sample first.
#[derive(Debug, Default)]
struct RecentStats {
    /// Recent latencies, oldest at the front.
    latencies: VecDeque<Duration>,
    /// Recent recalls, kept index-aligned with `latencies`.
    recalls: VecDeque<f32>,
    /// Number of samples to retain. A window of 0 still keeps the latest sample.
    window_size: usize,
}

impl RecentStats {
    /// Creates an empty window that retains up to `window_size` samples.
    fn new(window_size: usize) -> Self {
        RecentStats {
            window_size,
            latencies: VecDeque::with_capacity(window_size),
            recalls: VecDeque::with_capacity(window_size),
        }
    }

    /// Appends one sample. When the window is already full, the oldest sample is evicted first.
    fn record(&mut self, latency: Duration, recall: f32) {
        if self.window_size <= self.latencies.len() {
            // Evict from both queues together so they stay index-aligned.
            let _ = (self.latencies.pop_front(), self.recalls.pop_front());
        }
        self.latencies.push_back(latency);
        self.recalls.push_back(recall);
    }

    /// Mean latency of the retained samples, or `None` when there are none.
    fn avg_latency(&self) -> Option<Duration> {
        match self.latencies.len() {
            0 => None,
            count => Some(self.latencies.iter().sum::<Duration>() / count as u32),
        }
    }

    /// Mean recall of the retained samples, or `None` when there are none.
    // Not read by the planner yet; kept so recall feedback can be consumed later.
    #[allow(dead_code)]
    fn avg_recall(&self) -> Option<f32> {
        match self.recalls.len() {
            0 => None,
            count => Some(self.recalls.iter().sum::<f32>() / count as f32),
        }
    }
}
387
388impl SearchPlanner {
389 pub fn new() -> Self {
391 Self {
392 cost_model: CostModel::default(),
393 recent_stats: RwLock::new(RecentStats::new(100)),
394 query_count: AtomicU64::new(0),
395 }
396 }
397
398 pub fn with_cost_model(cost_model: CostModel) -> Self {
400 Self {
401 cost_model,
402 recent_stats: RwLock::new(RecentStats::new(100)),
403 query_count: AtomicU64::new(0),
404 }
405 }
406
407 pub fn plan(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
409 self.query_count.fetch_add(1, Ordering::Relaxed);
410
411 match sla.mode {
413 OptimizationMode::Speed => self.plan_for_speed(k, sla, stats),
414 OptimizationMode::Quality => self.plan_for_quality(k, sla, stats),
415 OptimizationMode::Balanced => self.plan_balanced(k, sla, stats),
416 OptimizationMode::SLO => self.plan_for_slo(k, sla, stats),
417 }
418 }
419
420 fn plan_for_speed(&self, k: usize, _sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
422 let ef = k.max(16);
424
425 if stats.total_vectors > 100_000 && stats.available_levels.contains(&StageQuantLevel::BPS) {
426 let coarse_count = (stats.total_vectors as f32 * 0.01).max(1000.0) as usize;
428
429 SearchPlan {
430 stages: vec![
431 PipelineStage {
432 quant_level: StageQuantLevel::BPS,
433 input_candidates: stats.total_vectors,
434 output_candidates: coarse_count,
435 apply_filter: true,
436 },
437 PipelineStage {
438 quant_level: StageQuantLevel::I8,
439 input_candidates: coarse_count,
440 output_candidates: k,
441 apply_filter: false,
442 },
443 ],
444 ef_search: ef,
445 k,
446 use_batched_expansion: true,
447 prefetch_distance: 8,
448 estimated_latency: Duration::from_micros(500),
449 estimated_recall: 0.85,
450 created_at: Instant::now(),
451 }
452 } else {
453 let level = if stats.available_levels.contains(&StageQuantLevel::I8) {
455 StageQuantLevel::I8
456 } else {
457 StageQuantLevel::F32
458 };
459
460 SearchPlan {
461 stages: vec![PipelineStage {
462 quant_level: level,
463 input_candidates: ef * 4,
464 output_candidates: k,
465 apply_filter: true,
466 }],
467 ef_search: ef,
468 k,
469 use_batched_expansion: true,
470 prefetch_distance: 4,
471 estimated_latency: Duration::from_millis(1),
472 estimated_recall: 0.90,
473 created_at: Instant::now(),
474 }
475 }
476 }
477
478 fn plan_for_quality(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
480 let ef = (k * 10).max(200);
482
483 SearchPlan {
484 stages: vec![PipelineStage {
485 quant_level: StageQuantLevel::F32,
486 input_candidates: ef,
487 output_candidates: k,
488 apply_filter: false, }],
490 ef_search: ef,
491 k,
492 use_batched_expansion: true,
493 prefetch_distance: 4,
494 estimated_latency: self.estimate_latency(ef, stats.dimension, StageQuantLevel::F32),
495 estimated_recall: sla.target_recall.min(0.99),
496 created_at: Instant::now(),
497 }
498 }
499
500 fn plan_balanced(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
502 let ef = (k * 4).max(64);
504
505 let use_pq =
507 stats.total_vectors > 10_000 && stats.available_levels.contains(&StageQuantLevel::PQ);
508 let use_i8 = stats.available_levels.contains(&StageQuantLevel::I8);
509
510 let mut stages = Vec::new();
511
512 if use_pq {
513 stages.push(PipelineStage {
514 quant_level: StageQuantLevel::PQ,
515 input_candidates: ef * 10,
516 output_candidates: ef * 2,
517 apply_filter: true,
518 });
519 }
520
521 let final_level = if sla.target_recall > 0.98 {
522 StageQuantLevel::F32
523 } else if use_i8 {
524 StageQuantLevel::I8
525 } else {
526 StageQuantLevel::F32
527 };
528
529 stages.push(PipelineStage {
530 quant_level: final_level,
531 input_candidates: if use_pq { ef * 2 } else { ef * 4 },
532 output_candidates: k,
533 apply_filter: !use_pq,
534 });
535
536 let estimated_recall = self.estimate_pipeline_recall(&stages, stats.total_vectors);
537 let estimated_latency = self.estimate_pipeline_latency(&stages, stats.dimension);
538
539 SearchPlan {
540 stages,
541 ef_search: ef,
542 k,
543 use_batched_expansion: true,
544 prefetch_distance: 4,
545 estimated_latency,
546 estimated_recall,
547 created_at: Instant::now(),
548 }
549 }
550
551 fn plan_for_slo(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
553 let recent = self.recent_stats.read().unwrap();
555
556 let base_plan = if let Some(avg_latency) = recent.avg_latency() {
557 if avg_latency > sla.latency_budget {
559 self.plan_for_speed(k, sla, stats)
561 } else if avg_latency < sla.latency_budget / 2 {
562 self.plan_for_quality(k, sla, stats)
564 } else {
565 self.plan_balanced(k, sla, stats)
566 }
567 } else {
568 self.plan_balanced(k, sla, stats)
570 };
571
572 if base_plan.estimated_latency > sla.latency_budget {
574 self.plan_for_speed(k, sla, stats)
576 } else {
577 base_plan
578 }
579 }
580
581 pub fn record_feedback(&self, latency: Duration, recall: f32) {
583 let mut stats = self.recent_stats.write().unwrap();
584 stats.record(latency, recall);
585 }
586
587 fn estimate_latency(
589 &self,
590 candidates: usize,
591 dimension: usize,
592 level: StageQuantLevel,
593 ) -> Duration {
594 let cost_per_candidate = match level {
595 StageQuantLevel::BPS => self.cost_model.stage_costs_ns.bps_per_candidate_ns,
596 StageQuantLevel::PQ => self.cost_model.stage_costs_ns.pq_per_candidate_ns,
597 StageQuantLevel::I8 => self.cost_model.stage_costs_ns.i8_per_candidate_ns,
598 StageQuantLevel::F32 => self.cost_model.stage_costs_ns.f32_per_candidate_ns,
599 };
600
601 let total_ns = candidates as f32 * cost_per_candidate * (dimension as f32 / 128.0);
602 Duration::from_nanos(total_ns as u64)
603 }
604
605 fn estimate_pipeline_recall(&self, stages: &[PipelineStage], total_vectors: usize) -> f32 {
607 stages
608 .iter()
609 .fold(1.0, |acc, stage| acc * stage.estimate_recall(total_vectors))
610 }
611
612 fn estimate_pipeline_latency(&self, stages: &[PipelineStage], dimension: usize) -> Duration {
614 stages
615 .iter()
616 .map(|stage| {
617 self.estimate_latency(stage.input_candidates, dimension, stage.quant_level)
618 })
619 .sum()
620 }
621
622 pub fn cost_model(&self) -> &CostModel {
624 &self.cost_model
625 }
626
627 pub fn query_count(&self) -> u64 {
629 self.query_count.load(Ordering::Relaxed)
630 }
631}
632
633impl Default for SearchPlanner {
634 fn default() -> Self {
635 Self::new()
636 }
637}
638
639pub struct PlanExecutor;
641
642impl PlanExecutor {
643 pub fn validate(plan: &SearchPlan) -> Result<(), PlanError> {
645 if plan.stages.is_empty() {
646 return Err(PlanError::EmptyPipeline);
647 }
648
649 if plan.k == 0 {
650 return Err(PlanError::InvalidK);
651 }
652
653 for window in plan.stages.windows(2) {
655 if window[0].output_candidates < window[1].input_candidates {
656 if window[0].output_candidates * 2 < window[1].input_candidates {
658 return Err(PlanError::StageOutputMismatch);
659 }
660 }
661 }
662
663 Ok(())
664 }
665}
666
/// Reasons a plan can fail `PlanExecutor::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PlanError {
    /// The plan has no stages.
    EmptyPipeline,
    /// The requested result count `k` is zero.
    InvalidK,
    /// A stage expects far more input candidates than its predecessor produces.
    StageOutputMismatch,
}

impl std::fmt::Display for PlanError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::EmptyPipeline => "Pipeline has no stages",
            Self::InvalidK => "k must be greater than 0",
            Self::StageOutputMismatch => "Stage output doesn't match next stage input",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PlanError {}
691
// Unit tests for plan construction, planner modes, feedback, cost/SLA helpers and validation.
#[cfg(test)]
mod tests {
    use super::*;

    // A one-stage plan keeps its k/ef parameters and passes validation.
    #[test]
    fn test_simple_plan() {
        let plan = SearchPlan::simple(10, 64);

        assert_eq!(plan.k, 10);
        assert_eq!(plan.ef_search, 64);
        assert_eq!(plan.stages.len(), 1);
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    // A 1M-vector dataset produces a multi-stage pipeline that passes validation.
    #[test]
    fn test_multi_stage_plan() {
        let plan = SearchPlan::multi_stage(10, 1_000_000, 0.95);

        assert_eq!(plan.k, 10);
        assert!(plan.stages.len() >= 2);
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    // Speed mode on a large dataset with BPS available should schedule a BPS stage.
    #[test]
    fn test_planner_speed_mode() {
        let planner = SearchPlanner::new();
        let sla = SearchSLA {
            mode: OptimizationMode::Speed,
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 1_000_000,
            dimension: 768,
            available_levels: vec![
                StageQuantLevel::BPS,
                StageQuantLevel::I8,
                StageQuantLevel::F32,
            ],
            ..Default::default()
        };

        let plan = planner.plan(10, &sla, &stats);

        // More than 100_000 vectors with BPS available selects the coarse BPS scan.
        assert!(
            plan.stages
                .iter()
                .any(|s| s.quant_level == StageQuantLevel::BPS)
        );
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    // Quality mode uses full-precision scoring with a large ef.
    #[test]
    fn test_planner_quality_mode() {
        let planner = SearchPlanner::new();
        let sla = SearchSLA {
            mode: OptimizationMode::Quality,
            target_recall: 0.99,
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 100_000,
            dimension: 768,
            available_levels: vec![StageQuantLevel::F32],
            ..Default::default()
        };

        let plan = planner.plan(10, &sla, &stats);

        assert!(
            plan.stages
                .iter()
                .any(|s| s.quant_level == StageQuantLevel::F32)
        );
        assert!(plan.ef_search >= 100);
    }

    // Balanced mode produces at least one stage and a valid pipeline.
    #[test]
    fn test_planner_balanced_mode() {
        let planner = SearchPlanner::new();
        let sla = SearchSLA {
            mode: OptimizationMode::Balanced,
            target_recall: 0.95,
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 100_000,
            dimension: 384,
            available_levels: vec![
                StageQuantLevel::PQ,
                StageQuantLevel::I8,
                StageQuantLevel::F32,
            ],
            ..Default::default()
        };

        let plan = planner.plan(10, &sla, &stats);

        assert!(plan.stages.len() >= 1);
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    // SLO mode reacts to recorded feedback.
    #[test]
    fn test_feedback_adaptation() {
        let planner = SearchPlanner::new();

        // Record consistently fast queries: 100 µs is far under the 5 ms budget below.
        for _ in 0..10 {
            planner.record_feedback(Duration::from_micros(100), 0.98);
        }

        let sla = SearchSLA {
            mode: OptimizationMode::SLO,
            latency_budget: Duration::from_millis(5),
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 100_000,
            dimension: 384,
            available_levels: vec![StageQuantLevel::F32],
            ..Default::default()
        };

        let plan = planner.plan(10, &sla, &stats);
        // Average latency under half the budget steers SLO planning toward the quality plan;
        // this only checks loosely that ef did not shrink.
        assert!(plan.ef_search >= 64);
    }

    // A non-empty plan has a positive abstract cost.
    #[test]
    fn test_plan_cost_estimation() {
        let plan = SearchPlan::simple(10, 64);
        let cost_model = CostModel::default();

        let cost = plan.total_cost(384, &cost_model);
        assert!(cost > 0.0);
    }

    // meets_sla checks both recall and latency bounds.
    #[test]
    fn test_plan_meets_sla() {
        let plan = SearchPlan {
            stages: vec![],
            ef_search: 64,
            k: 10,
            use_batched_expansion: true,
            prefetch_distance: 4,
            estimated_latency: Duration::from_millis(2),
            estimated_recall: 0.96,
            created_at: Instant::now(),
        };

        let sla = SearchSLA {
            target_recall: 0.95,
            latency_budget: Duration::from_millis(5),
            ..Default::default()
        };

        assert!(plan.meets_sla(&sla));

        // Fails both bounds: recall 0.96 < 0.99 and latency 2 ms > 1 ms.
        let strict_sla = SearchSLA {
            target_recall: 0.99,
            latency_budget: Duration::from_millis(1),
            ..Default::default()
        };

        assert!(!plan.meets_sla(&strict_sla));
    }

    // validate rejects empty pipelines and k == 0.
    #[test]
    fn test_invalid_plan() {
        let empty_plan = SearchPlan {
            stages: vec![],
            ef_search: 64,
            k: 10,
            use_batched_expansion: true,
            prefetch_distance: 4,
            estimated_latency: Duration::from_millis(1),
            estimated_recall: 0.95,
            created_at: Instant::now(),
        };

        assert_eq!(
            PlanExecutor::validate(&empty_plan),
            Err(PlanError::EmptyPipeline)
        );

        let zero_k_plan = SearchPlan {
            stages: vec![PipelineStage {
                quant_level: StageQuantLevel::F32,
                input_candidates: 64,
                output_candidates: 0,
                apply_filter: false,
            }],
            ef_search: 64,
            k: 0,
            use_batched_expansion: true,
            prefetch_distance: 4,
            estimated_latency: Duration::from_millis(1),
            estimated_recall: 0.95,
            created_at: Instant::now(),
        };

        assert_eq!(
            PlanExecutor::validate(&zero_k_plan),
            Err(PlanError::InvalidK)
        );
    }

    // Relative costs increase strictly from BPS to F32.
    #[test]
    fn test_stage_relative_costs() {
        assert!(StageQuantLevel::BPS.relative_cost() < StageQuantLevel::PQ.relative_cost());
        assert!(StageQuantLevel::PQ.relative_cost() < StageQuantLevel::I8.relative_cost());
        assert!(StageQuantLevel::I8.relative_cost() < StageQuantLevel::F32.relative_cost());
    }
}
906}