Skip to main content

sochdb_vector/
search_plan.rs

// Copyright 2025 SochDB Authors
//
// Licensed under the Apache License, Version 2.0

//! Quantization-aware Search Plan
//!
//! This module provides a formal runtime plan for vector search that separates
//! policy (what to optimize for) from mechanism (how to execute).
//!
//! # Architecture
//!
//! ```text
//! SearchRequest + SLA → Planner → SearchPlan → Executor → Results
//!                          ↑
//!                    Cost Model + Statistics
//! ```
//!
//! # Policy vs Mechanism
//!
//! **Policy** (what to optimize):
//! - Target recall@k (e.g., 0.95)
//! - Latency budget (e.g., 5ms p99)
//! - Token/compute budget
//!
//! **Mechanism** (how to execute):
//! - BPS coarse scan parameters
//! - PQ scoring parameters
//! - Rerank depth and method
//! - ef_search value
//! - Filter evaluation order
//!
//! # Cost Model
//!
//! The planner uses measured per-stage costs:
//! - `cost_bps(N, D)` = N × D × c_bps
//! - `cost_pq(N, D, M)` = N × M × c_pq
//! - `cost_rerank(N, D)` = N × D × c_f32
//!
//! # Optimization
//!
//! Minimize expected latency subject to:
//! - recall@k ≥ target_recall
//! - total_cost ≤ budget
//!
//! Uses bandit-like adaptation based on recent query statistics.
46
47use std::collections::VecDeque;
48use std::sync::RwLock;
49use std::sync::atomic::{AtomicU64, Ordering};
50use std::time::{Duration, Instant};
51
/// Precision tier at which a pipeline stage scores its candidates.
///
/// Variants are listed from the cheapest, least accurate representation to the
/// most expensive, exact one (see [`StageQuantLevel::relative_cost`] and
/// [`StageQuantLevel::expected_recall`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StageQuantLevel {
    /// Block Projection Sketch, used for coarse filtering.
    BPS,
    /// Product Quantization codes.
    PQ,
    /// 8-bit integer vectors.
    I8,
    /// Full-precision `f32` vectors.
    F32,
}

impl StageQuantLevel {
    /// Cost of one vector-dimension operation at this level, expressed
    /// relative to full precision (`F32` is `1.0`).
    pub const fn relative_cost(self) -> f32 {
        match self {
            Self::F32 => 1.00,
            Self::I8 => 0.25,
            Self::PQ => 0.10,
            Self::BPS => 0.05,
        }
    }

    /// Recall this level is expected to reach on its own (`F32` is `1.0`).
    pub const fn expected_recall(self) -> f32 {
        match self {
            Self::F32 => 1.00,
            Self::I8 => 0.995,
            Self::PQ => 0.90,
            Self::BPS => 0.70,
        }
    }
}
86
87/// A single stage in the search pipeline.
88#[derive(Debug, Clone)]
89pub struct PipelineStage {
90    /// Quantization level for this stage.
91    pub quant_level: StageQuantLevel,
92    /// Number of candidates to consider.
93    pub input_candidates: usize,
94    /// Number of candidates to output.
95    pub output_candidates: usize,
96    /// Whether to apply filters at this stage.
97    pub apply_filter: bool,
98}
99
100impl PipelineStage {
101    /// Estimate the cost of this stage.
102    pub fn estimate_cost(&self, dimension: usize, cost_model: &CostModel) -> f32 {
103        let base_cost =
104            self.input_candidates as f32 * dimension as f32 * self.quant_level.relative_cost();
105        base_cost * cost_model.cpu_cycles_per_op
106    }
107
108    /// Estimate the recall of this stage.
109    pub fn estimate_recall(&self, total_vectors: usize) -> f32 {
110        let coverage = (self.input_candidates as f32 / total_vectors as f32).min(1.0);
111        self.quant_level.expected_recall() * coverage.sqrt()
112    }
113}
114
115/// Service Level Agreement for search.
116#[derive(Debug, Clone)]
117pub struct SearchSLA {
118    /// Target recall@k (0.0 to 1.0).
119    pub target_recall: f32,
120    /// Maximum latency budget.
121    pub latency_budget: Duration,
122    /// Maximum compute tokens (relative units).
123    pub token_budget: Option<u64>,
124    /// Optimization mode.
125    pub mode: OptimizationMode,
126}
127
128impl Default for SearchSLA {
129    fn default() -> Self {
130        Self {
131            target_recall: 0.95,
132            latency_budget: Duration::from_millis(10),
133            token_budget: None,
134            mode: OptimizationMode::Balanced,
135        }
136    }
137}
138
/// Trade-off the planner is asked to favour when building a plan.
///
/// [`OptimizationMode::Balanced`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizationMode {
    /// Speed priority: minimize latency.
    Speed,
    /// Quality priority: maximize recall.
    Quality,
    /// Trade latency and recall against each other.
    #[default]
    Balanced,
    /// Enforce the SLO strictly.
    SLO,
}
152
/// Hardware-calibrated parameters used to turn candidate counts into cost
/// and latency estimates.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Normalized CPU cycles per operation.
    pub cpu_cycles_per_op: f32,
    /// Memory bandwidth, in GB/s.
    pub memory_bandwidth_gbps: f32,
    /// Size of the L3 cache, in bytes.
    pub l3_cache_bytes: usize,
    /// Measured cost of each stage, in nanoseconds per candidate.
    pub stage_costs_ns: StageCosts,
}

/// Measured per-candidate cost of each stage type, in nanoseconds.
///
/// The derived `Default` sets every cost to zero; the calibrated defaults live
/// in `CostModel::default()`.
#[derive(Debug, Clone, Default)]
pub struct StageCosts {
    /// Nanoseconds per candidate for the BPS scan.
    pub bps_per_candidate_ns: f32,
    /// Nanoseconds per candidate for PQ scoring.
    pub pq_per_candidate_ns: f32,
    /// Nanoseconds per candidate for the I8 rerank.
    pub i8_per_candidate_ns: f32,
    /// Nanoseconds per candidate for the F32 rerank.
    pub f32_per_candidate_ns: f32,
}

impl Default for CostModel {
    /// Baseline model: 1.0 cycles/op, 50 GB/s, a 32 MiB L3 cache and
    /// 10 / 50 / 100 / 500 ns per candidate for BPS / PQ / I8 / F32.
    fn default() -> Self {
        let stage_costs_ns = StageCosts {
            bps_per_candidate_ns: 10.0,
            pq_per_candidate_ns: 50.0,
            i8_per_candidate_ns: 100.0,
            f32_per_candidate_ns: 500.0,
        };

        Self {
            cpu_cycles_per_op: 1.0,
            memory_bandwidth_gbps: 50.0,
            l3_cache_bytes: 32 * 1024 * 1024,
            stage_costs_ns,
        }
    }
}
194
195/// Statistics about the dataset for planning.
196#[derive(Debug, Clone)]
197pub struct DatasetStats {
198    /// Total number of vectors.
199    pub total_vectors: usize,
200    /// Vector dimension.
201    pub dimension: usize,
202    /// Available quantization levels.
203    pub available_levels: Vec<StageQuantLevel>,
204    /// Filter selectivity (if known).
205    pub filter_selectivity: Option<f32>,
206    /// Recent query latency histogram (p50, p90, p99).
207    pub recent_latencies: Option<(Duration, Duration, Duration)>,
208}
209
210impl Default for DatasetStats {
211    fn default() -> Self {
212        Self {
213            total_vectors: 0,
214            dimension: 0,
215            available_levels: vec![StageQuantLevel::F32],
216            filter_selectivity: None,
217            recent_latencies: None,
218        }
219    }
220}
221
222/// The search plan: a complete specification for executing a search.
223#[derive(Debug, Clone)]
224pub struct SearchPlan {
225    /// Pipeline stages in execution order.
226    pub stages: Vec<PipelineStage>,
227    /// ef_search parameter for HNSW.
228    pub ef_search: usize,
229    /// Final k to return.
230    pub k: usize,
231    /// Whether to use batched expansion.
232    pub use_batched_expansion: bool,
233    /// Prefetch distance (0 = disabled).
234    pub prefetch_distance: usize,
235    /// Estimated total latency.
236    pub estimated_latency: Duration,
237    /// Estimated recall.
238    pub estimated_recall: f32,
239    /// Plan generation timestamp.
240    pub created_at: Instant,
241}
242
243impl SearchPlan {
244    /// Create a simple single-stage plan (F32 only).
245    pub fn simple(k: usize, ef_search: usize) -> Self {
246        Self {
247            stages: vec![PipelineStage {
248                quant_level: StageQuantLevel::F32,
249                input_candidates: ef_search,
250                output_candidates: k,
251                apply_filter: false,
252            }],
253            ef_search,
254            k,
255            use_batched_expansion: true,
256            prefetch_distance: 4,
257            estimated_latency: Duration::from_millis(1),
258            estimated_recall: 0.95,
259            created_at: Instant::now(),
260        }
261    }
262
263    /// Create a multi-stage plan with BPS → PQ → F32 pipeline.
264    pub fn multi_stage(k: usize, total_vectors: usize, target_recall: f32) -> Self {
265        // Calculate candidate counts for each stage
266        let coarse_candidates = (total_vectors as f32 * 0.1).min(10000.0) as usize;
267        let refine_candidates = (coarse_candidates as f32 * 0.1).max(k as f32 * 10.0) as usize;
268        let _rerank_candidates = (refine_candidates as f32 * 0.5).max(k as f32 * 2.0) as usize;
269
270        let mut stages = Vec::new();
271
272        // BPS coarse stage (if dataset is large enough)
273        if total_vectors > 10_000 {
274            stages.push(PipelineStage {
275                quant_level: StageQuantLevel::BPS,
276                input_candidates: total_vectors,
277                output_candidates: coarse_candidates,
278                apply_filter: true, // Early filter
279            });
280        }
281
282        // PQ refinement stage
283        if total_vectors > 1_000 {
284            stages.push(PipelineStage {
285                quant_level: StageQuantLevel::PQ,
286                input_candidates: coarse_candidates,
287                output_candidates: refine_candidates,
288                apply_filter: false,
289            });
290        }
291
292        // I8 or F32 rerank (choose based on recall target)
293        let rerank_level = if target_recall > 0.99 {
294            StageQuantLevel::F32
295        } else {
296            StageQuantLevel::I8
297        };
298
299        stages.push(PipelineStage {
300            quant_level: rerank_level,
301            input_candidates: refine_candidates,
302            output_candidates: k,
303            apply_filter: false,
304        });
305
306        Self {
307            stages,
308            ef_search: refine_candidates.max(64),
309            k,
310            use_batched_expansion: true,
311            prefetch_distance: 4,
312            estimated_latency: Duration::from_millis(5),
313            estimated_recall: target_recall,
314            created_at: Instant::now(),
315        }
316    }
317
318    /// Get the total estimated cost.
319    pub fn total_cost(&self, dimension: usize, cost_model: &CostModel) -> f32 {
320        self.stages
321            .iter()
322            .map(|s| s.estimate_cost(dimension, cost_model))
323            .sum()
324    }
325
326    /// Check if the plan meets the SLA.
327    pub fn meets_sla(&self, sla: &SearchSLA) -> bool {
328        self.estimated_recall >= sla.target_recall && self.estimated_latency <= sla.latency_budget
329    }
330}
331
332/// Search planner that generates optimal plans.
333pub struct SearchPlanner {
334    /// Cost model for estimation.
335    cost_model: CostModel,
336    /// Recent query statistics for adaptation.
337    recent_stats: RwLock<RecentStats>,
338    /// Query counter for bandit adaptation.
339    query_count: AtomicU64,
340}
341
/// Sliding window over the most recent query outcomes, used for adaptive planning.
#[derive(Debug, Default)]
struct RecentStats {
    /// Latencies of the most recent queries, oldest first.
    latencies: VecDeque<Duration>,
    /// Recalls of the most recent queries, oldest first.
    recalls: VecDeque<f32>,
    /// Number of samples after which the oldest one is evicted.
    window_size: usize,
}

impl RecentStats {
    /// Creates an empty window with room reserved for `window_size` samples.
    fn new(window_size: usize) -> Self {
        let latencies = VecDeque::with_capacity(window_size);
        let recalls = VecDeque::with_capacity(window_size);
        Self {
            latencies,
            recalls,
            window_size,
        }
    }

    /// Appends one sample, evicting the oldest sample first when the window
    /// already holds `window_size` entries.
    fn record(&mut self, latency: Duration, recall: f32) {
        let window_full = self.latencies.len() >= self.window_size;
        if window_full {
            self.latencies.pop_front();
            self.recalls.pop_front();
        }
        self.latencies.push_back(latency);
        self.recalls.push_back(recall);
    }

    /// Mean latency over the window, or `None` when no sample was recorded.
    fn avg_latency(&self) -> Option<Duration> {
        match self.latencies.len() {
            0 => None,
            samples => {
                let total: Duration = self.latencies.iter().sum();
                Some(total / samples as u32)
            }
        }
    }

    /// Mean recall over the window, or `None` when no sample was recorded.
    #[allow(dead_code)]
    fn avg_recall(&self) -> Option<f32> {
        match self.recalls.len() {
            0 => None,
            samples => Some(self.recalls.iter().sum::<f32>() / samples as f32),
        }
    }
}
387
388impl SearchPlanner {
389    /// Create a new search planner with default cost model.
390    pub fn new() -> Self {
391        Self {
392            cost_model: CostModel::default(),
393            recent_stats: RwLock::new(RecentStats::new(100)),
394            query_count: AtomicU64::new(0),
395        }
396    }
397
398    /// Create a planner with custom cost model.
399    pub fn with_cost_model(cost_model: CostModel) -> Self {
400        Self {
401            cost_model,
402            recent_stats: RwLock::new(RecentStats::new(100)),
403            query_count: AtomicU64::new(0),
404        }
405    }
406
407    /// Generate an optimal search plan.
408    pub fn plan(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
409        self.query_count.fetch_add(1, Ordering::Relaxed);
410
411        // Choose planning strategy based on mode
412        match sla.mode {
413            OptimizationMode::Speed => self.plan_for_speed(k, sla, stats),
414            OptimizationMode::Quality => self.plan_for_quality(k, sla, stats),
415            OptimizationMode::Balanced => self.plan_balanced(k, sla, stats),
416            OptimizationMode::SLO => self.plan_for_slo(k, sla, stats),
417        }
418    }
419
420    /// Plan optimized for speed.
421    fn plan_for_speed(&self, k: usize, _sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
422        // Use aggressive coarse filtering
423        let ef = k.max(16);
424
425        if stats.total_vectors > 100_000 && stats.available_levels.contains(&StageQuantLevel::BPS) {
426            // BPS → I8 pipeline
427            let coarse_count = (stats.total_vectors as f32 * 0.01).max(1000.0) as usize;
428
429            SearchPlan {
430                stages: vec![
431                    PipelineStage {
432                        quant_level: StageQuantLevel::BPS,
433                        input_candidates: stats.total_vectors,
434                        output_candidates: coarse_count,
435                        apply_filter: true,
436                    },
437                    PipelineStage {
438                        quant_level: StageQuantLevel::I8,
439                        input_candidates: coarse_count,
440                        output_candidates: k,
441                        apply_filter: false,
442                    },
443                ],
444                ef_search: ef,
445                k,
446                use_batched_expansion: true,
447                prefetch_distance: 8,
448                estimated_latency: Duration::from_micros(500),
449                estimated_recall: 0.85,
450                created_at: Instant::now(),
451            }
452        } else {
453            // Simple I8 or F32
454            let level = if stats.available_levels.contains(&StageQuantLevel::I8) {
455                StageQuantLevel::I8
456            } else {
457                StageQuantLevel::F32
458            };
459
460            SearchPlan {
461                stages: vec![PipelineStage {
462                    quant_level: level,
463                    input_candidates: ef * 4,
464                    output_candidates: k,
465                    apply_filter: true,
466                }],
467                ef_search: ef,
468                k,
469                use_batched_expansion: true,
470                prefetch_distance: 4,
471                estimated_latency: Duration::from_millis(1),
472                estimated_recall: 0.90,
473                created_at: Instant::now(),
474            }
475        }
476    }
477
478    /// Plan optimized for quality.
479    fn plan_for_quality(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
480        // Use full F32 with high ef
481        let ef = (k * 10).max(200);
482
483        SearchPlan {
484            stages: vec![PipelineStage {
485                quant_level: StageQuantLevel::F32,
486                input_candidates: ef,
487                output_candidates: k,
488                apply_filter: false, // Filter after scoring for max recall
489            }],
490            ef_search: ef,
491            k,
492            use_batched_expansion: true,
493            prefetch_distance: 4,
494            estimated_latency: self.estimate_latency(ef, stats.dimension, StageQuantLevel::F32),
495            estimated_recall: sla.target_recall.min(0.99),
496            created_at: Instant::now(),
497        }
498    }
499
500    /// Balanced plan.
501    fn plan_balanced(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
502        // Multi-stage with adaptive parameters
503        let ef = (k * 4).max(64);
504
505        // Decide stages based on dataset size and available levels
506        let use_pq =
507            stats.total_vectors > 10_000 && stats.available_levels.contains(&StageQuantLevel::PQ);
508        let use_i8 = stats.available_levels.contains(&StageQuantLevel::I8);
509
510        let mut stages = Vec::new();
511
512        if use_pq {
513            stages.push(PipelineStage {
514                quant_level: StageQuantLevel::PQ,
515                input_candidates: ef * 10,
516                output_candidates: ef * 2,
517                apply_filter: true,
518            });
519        }
520
521        let final_level = if sla.target_recall > 0.98 {
522            StageQuantLevel::F32
523        } else if use_i8 {
524            StageQuantLevel::I8
525        } else {
526            StageQuantLevel::F32
527        };
528
529        stages.push(PipelineStage {
530            quant_level: final_level,
531            input_candidates: if use_pq { ef * 2 } else { ef * 4 },
532            output_candidates: k,
533            apply_filter: !use_pq,
534        });
535
536        let estimated_recall = self.estimate_pipeline_recall(&stages, stats.total_vectors);
537        let estimated_latency = self.estimate_pipeline_latency(&stages, stats.dimension);
538
539        SearchPlan {
540            stages,
541            ef_search: ef,
542            k,
543            use_batched_expansion: true,
544            prefetch_distance: 4,
545            estimated_latency,
546            estimated_recall,
547            created_at: Instant::now(),
548        }
549    }
550
551    /// Plan for strict SLO enforcement.
552    fn plan_for_slo(&self, k: usize, sla: &SearchSLA, stats: &DatasetStats) -> SearchPlan {
553        // Use adaptive feedback from recent stats
554        let recent = self.recent_stats.read().unwrap();
555
556        let base_plan = if let Some(avg_latency) = recent.avg_latency() {
557            // Adjust based on recent performance
558            if avg_latency > sla.latency_budget {
559                // We're too slow, reduce work
560                self.plan_for_speed(k, sla, stats)
561            } else if avg_latency < sla.latency_budget / 2 {
562                // We have headroom, increase quality
563                self.plan_for_quality(k, sla, stats)
564            } else {
565                self.plan_balanced(k, sla, stats)
566            }
567        } else {
568            // No history, start balanced
569            self.plan_balanced(k, sla, stats)
570        };
571
572        // Ensure we meet SLA
573        if base_plan.estimated_latency > sla.latency_budget {
574            // Fall back to speed mode
575            self.plan_for_speed(k, sla, stats)
576        } else {
577            base_plan
578        }
579    }
580
581    /// Record query execution feedback for adaptation.
582    pub fn record_feedback(&self, latency: Duration, recall: f32) {
583        let mut stats = self.recent_stats.write().unwrap();
584        stats.record(latency, recall);
585    }
586
587    /// Estimate latency for a stage.
588    fn estimate_latency(
589        &self,
590        candidates: usize,
591        dimension: usize,
592        level: StageQuantLevel,
593    ) -> Duration {
594        let cost_per_candidate = match level {
595            StageQuantLevel::BPS => self.cost_model.stage_costs_ns.bps_per_candidate_ns,
596            StageQuantLevel::PQ => self.cost_model.stage_costs_ns.pq_per_candidate_ns,
597            StageQuantLevel::I8 => self.cost_model.stage_costs_ns.i8_per_candidate_ns,
598            StageQuantLevel::F32 => self.cost_model.stage_costs_ns.f32_per_candidate_ns,
599        };
600
601        let total_ns = candidates as f32 * cost_per_candidate * (dimension as f32 / 128.0);
602        Duration::from_nanos(total_ns as u64)
603    }
604
605    /// Estimate recall for a pipeline.
606    fn estimate_pipeline_recall(&self, stages: &[PipelineStage], total_vectors: usize) -> f32 {
607        stages
608            .iter()
609            .fold(1.0, |acc, stage| acc * stage.estimate_recall(total_vectors))
610    }
611
612    /// Estimate latency for a pipeline.
613    fn estimate_pipeline_latency(&self, stages: &[PipelineStage], dimension: usize) -> Duration {
614        stages
615            .iter()
616            .map(|stage| {
617                self.estimate_latency(stage.input_candidates, dimension, stage.quant_level)
618            })
619            .sum()
620    }
621
622    /// Get current cost model.
623    pub fn cost_model(&self) -> &CostModel {
624        &self.cost_model
625    }
626
627    /// Get query count.
628    pub fn query_count(&self) -> u64 {
629        self.query_count.load(Ordering::Relaxed)
630    }
631}
632
633impl Default for SearchPlanner {
634    fn default() -> Self {
635        Self::new()
636    }
637}
638
639/// Plan executor that runs a search plan.
640pub struct PlanExecutor;
641
642impl PlanExecutor {
643    /// Validate a plan before execution.
644    pub fn validate(plan: &SearchPlan) -> Result<(), PlanError> {
645        if plan.stages.is_empty() {
646            return Err(PlanError::EmptyPipeline);
647        }
648
649        if plan.k == 0 {
650            return Err(PlanError::InvalidK);
651        }
652
653        // Check stage consistency
654        for window in plan.stages.windows(2) {
655            if window[0].output_candidates < window[1].input_candidates {
656                // Allow some slack for over-request
657                if window[0].output_candidates * 2 < window[1].input_candidates {
658                    return Err(PlanError::StageOutputMismatch);
659                }
660            }
661        }
662
663        Ok(())
664    }
665}
666
/// Reasons a [`SearchPlan`] can fail validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PlanError {
    /// The pipeline contains no stages.
    EmptyPipeline,
    /// `k` was zero; it must be greater than 0.
    InvalidK,
    /// A stage's output does not feed the next stage's input.
    StageOutputMismatch,
}

impl std::fmt::Display for PlanError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::EmptyPipeline => "Pipeline has no stages",
            Self::InvalidK => "k must be greater than 0",
            Self::StageOutputMismatch => "Stage output doesn't match next stage input",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PlanError {}
691
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plan around hand-written stages with otherwise fixed settings.
    fn plan_with(
        stages: Vec<PipelineStage>,
        k: usize,
        estimated_latency: Duration,
        estimated_recall: f32,
    ) -> SearchPlan {
        SearchPlan {
            stages,
            ef_search: 64,
            k,
            use_batched_expansion: true,
            prefetch_distance: 4,
            estimated_latency,
            estimated_recall,
            created_at: Instant::now(),
        }
    }

    /// Builds an F32 stage with the given input/output candidate counts.
    fn f32_stage(input_candidates: usize, output_candidates: usize) -> PipelineStage {
        PipelineStage {
            quant_level: StageQuantLevel::F32,
            input_candidates,
            output_candidates,
            apply_filter: false,
        }
    }

    #[test]
    fn test_simple_plan() {
        let plan = SearchPlan::simple(10, 64);

        assert_eq!(plan.k, 10);
        assert_eq!(plan.ef_search, 64);
        assert_eq!(plan.stages.len(), 1);
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    #[test]
    fn test_multi_stage_plan() {
        let plan = SearchPlan::multi_stage(10, 1_000_000, 0.95);

        assert_eq!(plan.k, 10);
        assert!(plan.stages.len() >= 2);
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    #[test]
    fn test_planner_speed_mode() {
        let planner = SearchPlanner::new();
        let sla = SearchSLA {
            mode: OptimizationMode::Speed,
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 1_000_000,
            dimension: 768,
            available_levels: vec![
                StageQuantLevel::BPS,
                StageQuantLevel::I8,
                StageQuantLevel::F32,
            ],
            ..Default::default()
        };

        let plan = planner.plan(10, &sla, &stats);

        // Speed mode should use BPS for large datasets
        assert!(
            plan.stages
                .iter()
                .any(|s| s.quant_level == StageQuantLevel::BPS)
        );
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    #[test]
    fn test_planner_quality_mode() {
        let planner = SearchPlanner::new();
        let sla = SearchSLA {
            mode: OptimizationMode::Quality,
            target_recall: 0.99,
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 100_000,
            dimension: 768,
            available_levels: vec![StageQuantLevel::F32],
            ..Default::default()
        };

        let plan = planner.plan(10, &sla, &stats);

        // Quality mode should use F32
        assert!(
            plan.stages
                .iter()
                .any(|s| s.quant_level == StageQuantLevel::F32)
        );
        assert!(plan.ef_search >= 100);
    }

    #[test]
    fn test_planner_balanced_mode() {
        let planner = SearchPlanner::new();
        let sla = SearchSLA {
            mode: OptimizationMode::Balanced,
            target_recall: 0.95,
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 100_000,
            dimension: 384,
            available_levels: vec![
                StageQuantLevel::PQ,
                StageQuantLevel::I8,
                StageQuantLevel::F32,
            ],
            ..Default::default()
        };

        let plan = planner.plan(10, &sla, &stats);

        assert!(!plan.stages.is_empty());
        assert!(PlanExecutor::validate(&plan).is_ok());
    }

    #[test]
    fn test_feedback_adaptation() {
        let planner = SearchPlanner::new();

        // Record some fast queries
        for _ in 0..10 {
            planner.record_feedback(Duration::from_micros(100), 0.98);
        }

        let sla = SearchSLA {
            mode: OptimizationMode::SLO,
            latency_budget: Duration::from_millis(5),
            ..Default::default()
        };
        let stats = DatasetStats {
            total_vectors: 100_000,
            dimension: 384,
            available_levels: vec![StageQuantLevel::F32],
            ..Default::default()
        };

        // With fast recent queries, SLO mode should choose quality. The
        // quality plan uses ef >= 200, whereas the balanced plan for k = 10
        // would use ef = 64, so this bound tells the two apart.
        let plan = planner.plan(10, &sla, &stats);
        assert!(plan.ef_search >= 200);
        assert!(
            plan.stages
                .iter()
                .all(|s| s.quant_level == StageQuantLevel::F32)
        );
        assert!(plan.estimated_latency <= sla.latency_budget);
    }

    #[test]
    fn test_plan_cost_estimation() {
        let plan = SearchPlan::simple(10, 64);
        let cost_model = CostModel::default();

        let cost = plan.total_cost(384, &cost_model);
        assert!(cost > 0.0);
    }

    #[test]
    fn test_plan_meets_sla() {
        let plan = plan_with(vec![], 10, Duration::from_millis(2), 0.96);

        let sla = SearchSLA {
            target_recall: 0.95,
            latency_budget: Duration::from_millis(5),
            ..Default::default()
        };

        assert!(plan.meets_sla(&sla));

        let strict_sla = SearchSLA {
            target_recall: 0.99,
            latency_budget: Duration::from_millis(1),
            ..Default::default()
        };

        assert!(!plan.meets_sla(&strict_sla));
    }

    #[test]
    fn test_invalid_plan() {
        let empty_plan = plan_with(vec![], 10, Duration::from_millis(1), 0.95);

        assert_eq!(
            PlanExecutor::validate(&empty_plan),
            Err(PlanError::EmptyPipeline)
        );

        let zero_k_plan = plan_with(vec![f32_stage(64, 0)], 0, Duration::from_millis(1), 0.95);

        assert_eq!(
            PlanExecutor::validate(&zero_k_plan),
            Err(PlanError::InvalidK)
        );
    }

    #[test]
    fn test_stage_output_mismatch() {
        // Second stage wants 10× what the first one outputs: rejected.
        let starved = plan_with(
            vec![f32_stage(1_000, 10), f32_stage(100, 10)],
            10,
            Duration::from_millis(1),
            0.95,
        );
        assert_eq!(
            PlanExecutor::validate(&starved),
            Err(PlanError::StageOutputMismatch)
        );

        // Over-request within the 2× slack is accepted.
        let within_slack = plan_with(
            vec![f32_stage(1_000, 60), f32_stage(100, 10)],
            10,
            Duration::from_millis(1),
            0.95,
        );
        assert!(PlanExecutor::validate(&within_slack).is_ok());
    }

    #[test]
    fn test_stage_relative_costs() {
        assert!(StageQuantLevel::BPS.relative_cost() < StageQuantLevel::PQ.relative_cost());
        assert!(StageQuantLevel::PQ.relative_cost() < StageQuantLevel::I8.relative_cost());
        assert!(StageQuantLevel::I8.relative_cost() < StageQuantLevel::F32.relative_cost());
    }
}