Skip to main content

sochdb_query/
cost_optimizer.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Cost-Based Query Optimizer with Cardinality Estimation (Task 6)
19//!
20//! Provides cost-based query optimization for SOCH-QL with:
21//! - Cardinality estimation using sketches (HyperLogLog, CountMin)
22//! - Index selection: compare cost(table_scan) vs cost(index_seek)
23//! - Column projection pushdown to LSCS layer
24//! - Token-budget-aware planning
25//!
26//! ## Cost Model
27//!
28//! cost(plan) = I/O_cost + CPU_cost + memory_cost
29//!
30//! I/O_cost = blocks_read × C_seq + seeks × C_random
31//! Where:
32//!   C_seq = 0.1 ms/block (sequential read)
33//!   C_random = 5 ms/seek (random seek)
34//!
35//! CPU_cost = rows_processed × C_filter + sorts × N × log(N) × C_compare
36//!
37//! ## Selectivity Estimation
38//!
39//! Uses CountMinSketch for predicate selectivity and HyperLogLog for distinct counts.
40//!
//! ## Token Budget Planning
//!
//! Given a token budget, estimates how many result rows fit and injects a LIMIT clause:
//!   max_rows = max(1, (max_tokens - header_tokens) / tokens_per_row)
//! If the query already carries a smaller LIMIT, that limit is kept.
45
46use parking_lot::RwLock;
47use std::collections::{HashMap, HashSet};
48use std::sync::Arc;
49use std::time::{SystemTime, UNIX_EPOCH};
50
51// ============================================================================
52// Cost Model Constants
53// ============================================================================
54
/// Cost model configuration with empirically-derived constants.
///
/// All cost terms are expressed in milliseconds so that I/O, CPU and memory
/// contributions can be summed into a single plan cost.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Sequential I/O cost per block (ms)
    pub c_seq: f64,
    /// Random I/O cost per seek (ms)
    pub c_random: f64,
    /// CPU cost per row filter (ms)
    pub c_filter: f64,
    /// CPU cost per comparison during sort (ms)
    pub c_compare: f64,
    /// Block size in bytes
    pub block_size: usize,
    /// B-tree fanout for index cost estimation
    pub btree_fanout: usize,
    /// Memory bandwidth (bytes/ms)
    pub memory_bandwidth: f64,
}

impl Default for CostModelConfig {
    fn default() -> Self {
        Self {
            c_seq: 0.1,        // 0.1 ms per block sequential
            c_random: 5.0,     // 5 ms per random seek
            c_filter: 0.001,   // 0.001 ms per row filter
            c_compare: 0.0001, // 0.0001 ms per comparison
            block_size: 4096,  // 4 KB blocks
            btree_fanout: 100, // 100 entries per B-tree node
            // 10 GB/s = 10^10 bytes / 1000 ms = 10^7 bytes/ms.
            // The former value of 10_000 bytes/ms was only 10 MB/s. That is slower
            // than the ~40 MB/s sequential disk rate implied by `c_seq` above.
            memory_bandwidth: 10_000_000.0,
        }
    }
}
87
88// ============================================================================
89// Statistics for Cardinality Estimation
90// ============================================================================
91
92/// Table statistics for cost estimation
93#[derive(Debug, Clone)]
94pub struct TableStats {
95    /// Table name
96    pub name: String,
97    /// Total row count
98    pub row_count: u64,
99    /// Total size in bytes
100    pub size_bytes: u64,
101    /// Column statistics
102    pub column_stats: HashMap<String, ColumnStats>,
103    /// Available indices
104    pub indices: Vec<IndexStats>,
105    /// Last update timestamp
106    pub last_updated: u64,
107}
108
109/// Column statistics
110#[derive(Debug, Clone)]
111pub struct ColumnStats {
112    /// Column name
113    pub name: String,
114    /// Distinct value count (from HyperLogLog)
115    pub distinct_count: u64,
116    /// Null count
117    pub null_count: u64,
118    /// Minimum value (if orderable)
119    pub min_value: Option<String>,
120    /// Maximum value (if orderable)
121    pub max_value: Option<String>,
122    /// Average length in bytes (for variable-length types)
123    pub avg_length: f64,
124    /// Most common values with frequencies
125    pub mcv: Vec<(String, f64)>,
126    /// Histogram buckets for range queries
127    pub histogram: Option<Histogram>,
128}
129
/// Histogram for range selectivity estimation.
///
/// `boundaries` are the split points between buckets. Bucket `i` spans from
/// `boundaries[i - 1]` to `boundaries[i]`. The first bucket is open towards
/// negative infinity and the bucket after the last boundary is open towards
/// positive infinity, so `counts` is expected to hold `boundaries.len() + 1`
/// entries.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Bucket boundaries (split points between consecutive buckets)
    pub boundaries: Vec<f64>,
    /// Row count per bucket
    pub counts: Vec<u64>,
    /// Total rows in histogram
    pub total_rows: u64,
}

impl Histogram {
    /// Estimate the fraction of rows whose value lies in `[min, max]`.
    ///
    /// `None` leaves that side of the range unbounded. A bucket contributes its
    /// whole count whenever it overlaps the requested range. There is no
    /// interpolation inside a bucket, so the estimate errs on the high side.
    ///
    /// Returns `0.5` for an empty histogram and `0.0` for an inverted range
    /// (`min > max`). Otherwise the result is clamped to `[0.0, 1.0]`.
    pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
        if self.total_rows == 0 {
            return 0.5; // Default
        }

        // An inverted range (e.g. `BETWEEN 10 AND 5`) matches no value at all.
        if let (Some(lo), Some(hi)) = (min, max) {
            if lo > hi {
                return 0.0;
            }
        }

        let mut selected_rows = 0u64;

        for (i, &count) in self.counts.iter().enumerate() {
            // Checked lookups: a histogram with more counts than
            // `boundaries.len() + 1` used to panic on indexing. Surplus buckets
            // are treated as lying above the last boundary.
            let bucket_min = if i == 0 {
                f64::NEG_INFINITY
            } else {
                self.boundaries
                    .get(i - 1)
                    .or(self.boundaries.last())
                    .copied()
                    .unwrap_or(f64::NEG_INFINITY)
            };
            let bucket_max = self.boundaries.get(i).copied().unwrap_or(f64::INFINITY);

            let overlaps = match (min, max) {
                (Some(lo), Some(hi)) => bucket_max >= lo && bucket_min <= hi,
                (Some(lo), None) => bucket_max >= lo,
                (None, Some(hi)) => bucket_min <= hi,
                (None, None) => true,
            };

            if overlaps {
                selected_rows = selected_rows.saturating_add(count);
            }
        }

        // Stale statistics can make the bucket counts add up to more than
        // `total_rows`; a selectivity is a probability and must not exceed 1.
        (selected_rows as f64 / self.total_rows as f64).min(1.0)
    }
}
177
178/// Index statistics
179#[derive(Debug, Clone)]
180pub struct IndexStats {
181    /// Index name
182    pub name: String,
183    /// Indexed columns
184    pub columns: Vec<String>,
185    /// Is primary key
186    pub is_primary: bool,
187    /// Is unique
188    pub is_unique: bool,
189    /// Index type
190    pub index_type: IndexType,
191    /// Number of leaf pages
192    pub leaf_pages: u64,
193    /// Tree height (for B-tree)
194    pub height: u32,
195    /// Average entries per leaf page
196    pub avg_leaf_density: f64,
197}
198
/// Kinds of index structure the optimizer can reason about.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IndexType {
    /// B-tree index.
    BTree,
    /// Hash index.
    Hash,
    /// Log-structured merge-tree index.
    LSM,
    /// Learned index.
    Learned,
    /// Vector index.
    Vector,
    /// Bloom-filter index.
    Bloom,
}
209
210// ============================================================================
211// Query Predicates and Operations
212// ============================================================================
213
/// Boolean filter expression used for selectivity and access-path estimation.
///
/// Literal operands are kept as strings; numeric interpretation is left to
/// the estimators that need it.
#[derive(Clone, Debug)]
pub enum Predicate {
    /// `column = value`
    Eq { column: String, value: String },
    /// `column != value`
    Ne { column: String, value: String },
    /// `column < value`
    Lt { column: String, value: String },
    /// `column <= value`
    Le { column: String, value: String },
    /// `column > value`
    Gt { column: String, value: String },
    /// `column >= value`
    Ge { column: String, value: String },
    /// `column BETWEEN min AND max`
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// `column IN (v1, v2, ...)`
    In { column: String, values: Vec<String> },
    /// `column LIKE pattern`
    Like { column: String, pattern: String },
    /// `column IS NULL`
    IsNull { column: String },
    /// `column IS NOT NULL`
    IsNotNull { column: String },
    /// Conjunction of two predicates.
    And(Box<Predicate>, Box<Predicate>),
    /// Disjunction of two predicates.
    Or(Box<Predicate>, Box<Predicate>),
    /// Negation of a predicate.
    Not(Box<Predicate>),
}

impl Predicate {
    /// Returns every column name mentioned anywhere in this predicate tree.
    pub fn referenced_columns(&self) -> HashSet<String> {
        let mut found = HashSet::new();
        self.collect_columns(&mut found);
        found
    }

    /// Walks the predicate tree with an explicit work stack and inserts each
    /// leaf's column name into `cols`.
    fn collect_columns(&self, cols: &mut HashSet<String>) {
        let mut pending: Vec<&Predicate> = vec![self];
        while let Some(node) = pending.pop() {
            match node {
                Self::And(lhs, rhs) | Self::Or(lhs, rhs) => {
                    pending.push(&**rhs);
                    pending.push(&**lhs);
                }
                Self::Not(inner) => pending.push(&**inner),
                Self::Eq { column, .. }
                | Self::Ne { column, .. }
                | Self::Lt { column, .. }
                | Self::Le { column, .. }
                | Self::Gt { column, .. }
                | Self::Ge { column, .. }
                | Self::Between { column, .. }
                | Self::In { column, .. }
                | Self::Like { column, .. }
                | Self::IsNull { column }
                | Self::IsNotNull { column } => {
                    cols.insert(column.clone());
                }
            }
        }
    }
}
282
283// ============================================================================
284// Physical Plan Operators
285// ============================================================================
286
287/// Physical query plan node
288#[derive(Debug, Clone)]
289pub enum PhysicalPlan {
290    /// Table scan (full or partial)
291    TableScan {
292        table: String,
293        columns: Vec<String>,
294        predicate: Option<Box<Predicate>>,
295        estimated_rows: u64,
296        estimated_cost: f64,
297    },
298    /// Index seek
299    IndexSeek {
300        table: String,
301        index: String,
302        columns: Vec<String>,
303        key_range: KeyRange,
304        predicate: Option<Box<Predicate>>,
305        estimated_rows: u64,
306        estimated_cost: f64,
307    },
308    /// Filter operator
309    Filter {
310        input: Box<PhysicalPlan>,
311        predicate: Predicate,
312        estimated_rows: u64,
313        estimated_cost: f64,
314    },
315    /// Project operator (column subset)
316    Project {
317        input: Box<PhysicalPlan>,
318        columns: Vec<String>,
319        estimated_cost: f64,
320    },
321    /// Sort operator
322    Sort {
323        input: Box<PhysicalPlan>,
324        order_by: Vec<(String, SortDirection)>,
325        estimated_cost: f64,
326    },
327    /// Limit operator
328    Limit {
329        input: Box<PhysicalPlan>,
330        limit: u64,
331        offset: u64,
332        estimated_cost: f64,
333    },
334    /// Nested loop join
335    NestedLoopJoin {
336        outer: Box<PhysicalPlan>,
337        inner: Box<PhysicalPlan>,
338        condition: Predicate,
339        join_type: JoinType,
340        estimated_rows: u64,
341        estimated_cost: f64,
342    },
343    /// Hash join
344    HashJoin {
345        build: Box<PhysicalPlan>,
346        probe: Box<PhysicalPlan>,
347        build_keys: Vec<String>,
348        probe_keys: Vec<String>,
349        join_type: JoinType,
350        estimated_rows: u64,
351        estimated_cost: f64,
352    },
353    /// Merge join
354    MergeJoin {
355        left: Box<PhysicalPlan>,
356        right: Box<PhysicalPlan>,
357        left_keys: Vec<String>,
358        right_keys: Vec<String>,
359        join_type: JoinType,
360        estimated_rows: u64,
361        estimated_cost: f64,
362    },
363    /// Aggregate operator
364    Aggregate {
365        input: Box<PhysicalPlan>,
366        group_by: Vec<String>,
367        aggregates: Vec<AggregateExpr>,
368        estimated_rows: u64,
369        estimated_cost: f64,
370    },
371}
372
/// Interval of raw byte keys read by an `IndexSeek`.
///
/// A `None` bound leaves the range open in that direction.
#[derive(Clone, Debug)]
pub struct KeyRange {
    pub start: Option<Vec<u8>>,
    pub end: Option<Vec<u8>>,
    pub start_inclusive: bool,
    pub end_inclusive: bool,
}

impl KeyRange {
    /// The unbounded range covering every key.
    pub fn all() -> Self {
        Self::range(None, None, true)
    }

    /// A range that contains exactly `key` (both ends inclusive).
    pub fn point(key: Vec<u8>) -> Self {
        let upper = key.clone();
        Self::range(Some(key), Some(upper), true)
    }

    /// A range from `start` to `end`; `inclusive` applies to both ends.
    pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
        Self {
            start_inclusive: inclusive,
            end_inclusive: inclusive,
            start,
            end,
        }
    }
}
410
/// Ordering applied to a sort key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortDirection {
    /// Smallest values first.
    Ascending,
    /// Largest values first.
    Descending,
}
417
/// SQL join flavour carried by the join plan nodes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
    /// Only matching row pairs.
    Inner,
    /// All left rows, matched where possible.
    Left,
    /// All right rows, matched where possible.
    Right,
    /// All rows from both sides.
    Full,
    /// Cartesian product.
    Cross,
}
427
/// One aggregate computed by an `Aggregate` plan node, e.g. `SUM(x) AS total`.
#[derive(Clone, Debug)]
pub struct AggregateExpr {
    /// Aggregate function to apply.
    pub function: AggregateFunction,
    /// Input column; `None` when the function takes no column (e.g. `COUNT(*)`).
    pub column: Option<String>,
    /// Output name of the aggregate.
    pub alias: String,
}

/// Aggregate functions supported by the planner.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateFunction {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    CountDistinct,
}
446
447// ============================================================================
448// Cost-Based Query Optimizer
449// ============================================================================
450
451/// Cost-based query optimizer
452pub struct CostBasedOptimizer {
453    /// Cost model configuration
454    config: CostModelConfig,
455    /// Table statistics cache
456    stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
457    /// Token budget for result limiting
458    token_budget: Option<u64>,
459    /// Estimated tokens per row
460    tokens_per_row: f64,
461    /// Plan cache: (table, predicate_hash, limit) -> (plan, timestamp_us)
462    plan_cache: Arc<RwLock<HashMap<u64, (PhysicalPlan, u64)>>>,
463    /// Plan cache TTL in microseconds (default 5 seconds)
464    plan_cache_ttl_us: u64,
465}
466
467impl CostBasedOptimizer {
468    pub fn new(config: CostModelConfig) -> Self {
469        Self {
470            config,
471            stats_cache: Arc::new(RwLock::new(HashMap::new())),
472            token_budget: None,
473            tokens_per_row: 25.0, // Default estimate
474            plan_cache: Arc::new(RwLock::new(HashMap::new())),
475            plan_cache_ttl_us: 5_000_000, // 5 seconds
476        }
477    }
478
479    /// Set plan cache TTL
480    pub fn with_plan_cache_ttl_ms(mut self, ttl_ms: u64) -> Self {
481        self.plan_cache_ttl_us = ttl_ms * 1000;
482        self
483    }
484
485    /// Set token budget for result limiting
486    pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
487        self.token_budget = Some(budget);
488        self.tokens_per_row = tokens_per_row;
489        self
490    }
491
492    /// Update table statistics
493    pub fn update_stats(&self, stats: TableStats) {
494        self.stats_cache.write().insert(stats.name.clone(), stats);
495    }
496
497    /// Get table statistics
498    pub fn get_stats(&self, table: &str) -> Option<TableStats> {
499        self.stats_cache.read().get(table).cloned()
500    }
501
502    /// Optimize a SELECT query
503    pub fn optimize(
504        &self,
505        table: &str,
506        columns: Vec<String>,
507        predicate: Option<Predicate>,
508        order_by: Vec<(String, SortDirection)>,
509        limit: Option<u64>,
510    ) -> PhysicalPlan {
511        let stats = self.get_stats(table);
512
513        // Calculate token-aware limit
514        let effective_limit = self.calculate_token_limit(limit);
515
516        // Get best access path (scan vs index)
517        let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
518
519        // Apply column projection pushdown
520        plan = self.apply_projection_pushdown(plan, columns.clone());
521
522        // Apply sorting if needed
523        if !order_by.is_empty() {
524            plan = self.add_sort(plan, order_by, &stats);
525        }
526
527        // Apply limit
528        if let Some(lim) = effective_limit {
529            plan = PhysicalPlan::Limit {
530                estimated_cost: 0.0,
531                input: Box::new(plan),
532                limit: lim,
533                offset: 0,
534            };
535        }
536
537        plan
538    }
539
540    /// Calculate token-aware limit
541    fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
542        match (self.token_budget, user_limit) {
543            (Some(budget), Some(limit)) => {
544                let header_tokens = 50u64;
545                let usable = budget.saturating_sub(header_tokens);
546                let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
547                Some(limit.min(max_rows))
548            }
549            (Some(budget), None) => {
550                let header_tokens = 50u64;
551                let usable = budget.saturating_sub(header_tokens);
552                let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
553                Some(max_rows)
554            }
555            (None, limit) => limit,
556        }
557    }
558
559    /// Choose best access path (table scan vs index seek)
560    fn choose_access_path(
561        &self,
562        table: &str,
563        columns: &[String],
564        predicate: Option<&Predicate>,
565        stats: &Option<TableStats>,
566    ) -> PhysicalPlan {
567        let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
568        let size_bytes = stats
569            .as_ref()
570            .map(|s| s.size_bytes)
571            .unwrap_or(row_count * 100);
572
573        // Calculate table scan cost
574        let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
575
576        // Try to find a suitable index
577        let mut best_index_cost = f64::MAX;
578        let mut best_index: Option<&IndexStats> = None;
579
580        if let Some(table_stats) = stats.as_ref()
581            && let Some(pred) = predicate
582        {
583            let pred_columns = pred.referenced_columns();
584
585            for index in &table_stats.indices {
586                if self.index_covers_predicate(index, &pred_columns) {
587                    let selectivity = self.estimate_selectivity(pred, table_stats);
588                    let index_cost = self.estimate_index_cost(index, row_count, selectivity);
589
590                    if index_cost < best_index_cost {
591                        best_index_cost = index_cost;
592                        best_index = Some(index);
593                    }
594                }
595            }
596        }
597
598        // Choose cheaper option
599        if best_index_cost < scan_cost {
600            let index = best_index.unwrap();
601            let selectivity = predicate
602                .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
603                .unwrap_or(1.0);
604
605            PhysicalPlan::IndexSeek {
606                table: table.to_string(),
607                index: index.name.clone(),
608                columns: columns.to_vec(),
609                key_range: predicate
610                    .map(|p| Self::derive_key_range(p))
611                    .unwrap_or_else(KeyRange::all),
612                predicate: predicate.map(|p| Box::new(p.clone())),
613                estimated_rows: (row_count as f64 * selectivity).max(1.0) as u64,
614                estimated_cost: best_index_cost,
615            }
616        } else {
617            PhysicalPlan::TableScan {
618                table: table.to_string(),
619                columns: columns.to_vec(),
620                predicate: predicate.map(|p| Box::new(p.clone())),
621                estimated_rows: row_count,
622                estimated_cost: scan_cost,
623            }
624        }
625    }
626
627    /// Check if index covers predicate columns
628    fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
629        // Index is useful if it covers at least the first column of the predicate
630        if let Some(first_col) = index.columns.first() {
631            pred_columns.contains(first_col)
632        } else {
633            false
634        }
635    }
636
637    /// Estimate table scan cost
638    ///
639    /// I/O: sequential read all blocks
640    /// CPU: evaluate predicate against every row
641    fn estimate_scan_cost(
642        &self,
643        row_count: u64,
644        size_bytes: u64,
645        _predicate: Option<&Predicate>,
646    ) -> f64 {
647        let blocks = (size_bytes as f64 / self.config.block_size as f64)
648            .ceil()
649            .max(1.0) as u64;
650
651        // I/O cost: must read all blocks regardless of predicate
652        let io_cost = blocks as f64 * self.config.c_seq;
653
654        // CPU cost: evaluate predicate on every row (scan reads them all)
655        let cpu_cost = row_count as f64 * self.config.c_filter;
656
657        io_cost + cpu_cost
658    }
659
660    /// Estimate index seek cost
661    ///
662    /// Index cost = tree_traversal + leaf_scan + row_fetch
663    fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
664        // Tree traversal cost (random I/O for each level)
665        let tree_cost = index.height as f64 * self.config.c_random;
666
667        // Leaf scan cost (sequential for matching range)
668        let matching_rows = (total_rows as f64 * selectivity) as u64;
669        let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
670        let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
671
672        // Row fetch cost (random if not clustered)
673        let fetch_cost = if index.is_primary {
674            0.0 // Clustered index, no extra fetch
675        } else {
676            matching_rows.min(1000) as f64 * self.config.c_random * 0.1 // Batch optimization
677        };
678
679        tree_cost + leaf_cost + fetch_cost
680    }
681
682    /// Estimate predicate selectivity
683    #[allow(clippy::only_used_in_recursion)]
684    fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
685        match predicate {
686            Predicate::Eq { column, value } => {
687                if let Some(col_stats) = stats.column_stats.get(column) {
688                    // Check MCV first
689                    for (mcv_val, freq) in &col_stats.mcv {
690                        if mcv_val == value {
691                            return *freq;
692                        }
693                    }
694                    // Otherwise use uniform distribution
695                    1.0 / col_stats.distinct_count.max(1) as f64
696                } else {
697                    0.1 // Default 10%
698                }
699            }
700            Predicate::Ne { .. } => 0.9, // 90% pass
701            Predicate::Lt { column, value }
702            | Predicate::Le { column, value }
703            | Predicate::Gt { column, value }
704            | Predicate::Ge { column, value } => {
705                if let Some(col_stats) = stats.column_stats.get(column) {
706                    if let Some(ref hist) = col_stats.histogram {
707                        let val: f64 = value.parse().unwrap_or(0.0);
708                        match predicate {
709                            Predicate::Lt { .. } | Predicate::Le { .. } => {
710                                hist.estimate_range_selectivity(None, Some(val))
711                            }
712                            _ => hist.estimate_range_selectivity(Some(val), None),
713                        }
714                    } else {
715                        0.25 // Default 25%
716                    }
717                } else {
718                    0.25
719                }
720            }
721            Predicate::Between { column, min, max } => {
722                if let Some(col_stats) = stats.column_stats.get(column) {
723                    if let Some(ref hist) = col_stats.histogram {
724                        let min_val: f64 = min.parse().unwrap_or(0.0);
725                        let max_val: f64 = max.parse().unwrap_or(f64::MAX);
726                        hist.estimate_range_selectivity(Some(min_val), Some(max_val))
727                    } else {
728                        0.2
729                    }
730                } else {
731                    0.2
732                }
733            }
734            Predicate::In { column, values } => {
735                if let Some(col_stats) = stats.column_stats.get(column) {
736                    (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
737                } else {
738                    (values.len() as f64 * 0.1).min(0.5)
739                }
740            }
741            Predicate::Like { .. } => 0.15, // Default 15%
742            Predicate::IsNull { column } => {
743                if let Some(col_stats) = stats.column_stats.get(column) {
744                    col_stats.null_count as f64 / stats.row_count.max(1) as f64
745                } else {
746                    0.01
747                }
748            }
749            Predicate::IsNotNull { column } => {
750                if let Some(col_stats) = stats.column_stats.get(column) {
751                    1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
752                } else {
753                    0.99
754                }
755            }
756            Predicate::And(left, right) => {
757                // Assume independence
758                self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
759            }
760            Predicate::Or(left, right) => {
761                let s1 = self.estimate_selectivity(left, stats);
762                let s2 = self.estimate_selectivity(right, stats);
763                // P(A or B) = P(A) + P(B) - P(A and B)
764                (s1 + s2 - s1 * s2).min(1.0)
765            }
766            Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
767        }
768    }
769
770    /// Derive key range from predicate for index seek
771    fn derive_key_range(predicate: &Predicate) -> KeyRange {
772        match predicate {
773            Predicate::Eq { value, .. } => KeyRange::point(value.as_bytes().to_vec()),
774            Predicate::Lt { value, .. } | Predicate::Le { value, .. } => KeyRange::range(
775                None,
776                Some(value.as_bytes().to_vec()),
777                matches!(predicate, Predicate::Le { .. }),
778            ),
779            Predicate::Gt { value, .. } | Predicate::Ge { value, .. } => KeyRange::range(
780                Some(value.as_bytes().to_vec()),
781                None,
782                matches!(predicate, Predicate::Ge { .. }),
783            ),
784            Predicate::Between { min, max, .. } => KeyRange {
785                start: Some(min.as_bytes().to_vec()),
786                end: Some(max.as_bytes().to_vec()),
787                start_inclusive: true,
788                end_inclusive: true,
789            },
790            Predicate::And(left, _) => Self::derive_key_range(left),
791            _ => KeyRange::all(),
792        }
793    }
794
795    /// Apply column projection pushdown
796    ///
797    /// Reduces I/O cost proportionally to the fraction of columns selected.
798    fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
799        match plan {
800            PhysicalPlan::TableScan {
801                ref table,
802                predicate,
803                estimated_rows,
804                estimated_cost,
805                columns: ref all_columns,
806                ..
807            } => {
808                // Cost reduction proportional to column selectivity
809                let col_ratio = if all_columns.is_empty() || columns.is_empty() {
810                    1.0
811                } else {
812                    (columns.len() as f64 / all_columns.len().max(1) as f64).clamp(0.1, 1.0)
813                };
814                PhysicalPlan::TableScan {
815                    table: table.clone(),
816                    columns,
817                    predicate,
818                    estimated_rows,
819                    estimated_cost: estimated_cost * col_ratio,
820                }
821            }
822            PhysicalPlan::IndexSeek {
823                table,
824                index,
825                key_range,
826                predicate,
827                estimated_rows,
828                estimated_cost,
829                ..
830            } => {
831                PhysicalPlan::IndexSeek {
832                    table,
833                    index,
834                    columns, // Pushed down columns
835                    key_range,
836                    predicate,
837                    estimated_rows,
838                    estimated_cost,
839                }
840            }
841            other => PhysicalPlan::Project {
842                input: Box::new(other),
843                columns,
844                estimated_cost: 0.0,
845            },
846        }
847    }
848
849    /// Add sort operator
850    fn add_sort(
851        &self,
852        plan: PhysicalPlan,
853        order_by: Vec<(String, SortDirection)>,
854        _stats: &Option<TableStats>,
855    ) -> PhysicalPlan {
856        let estimated_rows = self.get_plan_rows(&plan);
857        let sort_cost = if estimated_rows > 0 {
858            estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
859        } else {
860            0.0
861        };
862
863        PhysicalPlan::Sort {
864            input: Box::new(plan),
865            order_by,
866            estimated_cost: sort_cost,
867        }
868    }
869
870    /// Get estimated rows from a plan
871    #[allow(clippy::only_used_in_recursion)]
872    fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
873        match plan {
874            PhysicalPlan::TableScan { estimated_rows, .. }
875            | PhysicalPlan::IndexSeek { estimated_rows, .. }
876            | PhysicalPlan::Filter { estimated_rows, .. }
877            | PhysicalPlan::Aggregate { estimated_rows, .. }
878            | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
879            | PhysicalPlan::HashJoin { estimated_rows, .. }
880            | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
881            PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
882                self.get_plan_rows(input)
883            }
884            PhysicalPlan::Limit { limit, .. } => *limit,
885        }
886    }
887
888    /// Get estimated cost from a plan
889    #[allow(clippy::only_used_in_recursion)]
890    pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
891        match plan {
892            PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
893            PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
894            PhysicalPlan::Filter {
895                estimated_cost,
896                input,
897                ..
898            } => *estimated_cost + self.get_plan_cost(input),
899            PhysicalPlan::Project {
900                estimated_cost,
901                input,
902                ..
903            } => *estimated_cost + self.get_plan_cost(input),
904            PhysicalPlan::Sort {
905                estimated_cost,
906                input,
907                ..
908            } => *estimated_cost + self.get_plan_cost(input),
909            PhysicalPlan::Limit {
910                estimated_cost,
911                input,
912                ..
913            } => *estimated_cost + self.get_plan_cost(input),
914            PhysicalPlan::NestedLoopJoin {
915                estimated_cost,
916                outer,
917                inner,
918                ..
919            } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
920            PhysicalPlan::HashJoin {
921                estimated_cost,
922                build,
923                probe,
924                ..
925            } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
926            PhysicalPlan::MergeJoin {
927                estimated_cost,
928                left,
929                right,
930                ..
931            } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
932            PhysicalPlan::Aggregate {
933                estimated_cost,
934                input,
935                ..
936            } => *estimated_cost + self.get_plan_cost(input),
937        }
938    }
939
940    /// Generate EXPLAIN output
941    pub fn explain(&self, plan: &PhysicalPlan) -> String {
942        self.explain_impl(plan, 0)
943    }
944
945    fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
946        let prefix = "  ".repeat(indent);
947        let cost = self.get_plan_cost(plan);
948
949        match plan {
950            PhysicalPlan::TableScan {
951                table,
952                columns,
953                estimated_rows,
954                ..
955            } => {
956                format!(
957                    "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
958                    prefix, table, columns, estimated_rows, cost
959                )
960            }
961            PhysicalPlan::IndexSeek {
962                table,
963                index,
964                columns,
965                estimated_rows,
966                ..
967            } => {
968                format!(
969                    "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
970                    prefix, table, index, columns, estimated_rows, cost
971                )
972            }
973            PhysicalPlan::Filter {
974                input,
975                estimated_rows,
976                ..
977            } => {
978                format!(
979                    "{}Filter [rows={}, cost={:.2}ms]\n{}",
980                    prefix,
981                    estimated_rows,
982                    cost,
983                    self.explain_impl(input, indent + 1)
984                )
985            }
986            PhysicalPlan::Project { input, columns, .. } => {
987                format!(
988                    "{}Project [columns={:?}, cost={:.2}ms]\n{}",
989                    prefix,
990                    columns,
991                    cost,
992                    self.explain_impl(input, indent + 1)
993                )
994            }
995            PhysicalPlan::Sort {
996                input, order_by, ..
997            } => {
998                let order: Vec<_> = order_by
999                    .iter()
1000                    .map(|(c, d)| format!("{} {:?}", c, d))
1001                    .collect();
1002                format!(
1003                    "{}Sort [order={:?}, cost={:.2}ms]\n{}",
1004                    prefix,
1005                    order,
1006                    cost,
1007                    self.explain_impl(input, indent + 1)
1008                )
1009            }
1010            PhysicalPlan::Limit {
1011                input,
1012                limit,
1013                offset,
1014                ..
1015            } => {
1016                format!(
1017                    "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
1018                    prefix,
1019                    limit,
1020                    offset,
1021                    cost,
1022                    self.explain_impl(input, indent + 1)
1023                )
1024            }
1025            PhysicalPlan::HashJoin {
1026                build,
1027                probe,
1028                join_type,
1029                estimated_rows,
1030                ..
1031            } => {
1032                format!(
1033                    "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1034                    prefix,
1035                    join_type,
1036                    estimated_rows,
1037                    cost,
1038                    self.explain_impl(build, indent + 1),
1039                    self.explain_impl(probe, indent + 1)
1040                )
1041            }
1042            PhysicalPlan::MergeJoin {
1043                left,
1044                right,
1045                join_type,
1046                estimated_rows,
1047                ..
1048            } => {
1049                format!(
1050                    "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1051                    prefix,
1052                    join_type,
1053                    estimated_rows,
1054                    cost,
1055                    self.explain_impl(left, indent + 1),
1056                    self.explain_impl(right, indent + 1)
1057                )
1058            }
1059            PhysicalPlan::NestedLoopJoin {
1060                outer,
1061                inner,
1062                join_type,
1063                estimated_rows,
1064                ..
1065            } => {
1066                format!(
1067                    "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1068                    prefix,
1069                    join_type,
1070                    estimated_rows,
1071                    cost,
1072                    self.explain_impl(outer, indent + 1),
1073                    self.explain_impl(inner, indent + 1)
1074                )
1075            }
1076            PhysicalPlan::Aggregate {
1077                input,
1078                group_by,
1079                aggregates,
1080                estimated_rows,
1081                ..
1082            } => {
1083                let aggs: Vec<_> = aggregates
1084                    .iter()
1085                    .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
1086                    .collect();
1087                format!(
1088                    "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
1089                    prefix,
1090                    group_by,
1091                    aggs,
1092                    estimated_rows,
1093                    cost,
1094                    self.explain_impl(input, indent + 1)
1095                )
1096            }
1097        }
1098    }
1099}
1100
1101// ============================================================================
1102// Plan Cache & Stats Helpers
1103// ============================================================================
1104
1105impl CostBasedOptimizer {
1106    /// Evict stale entries from the plan cache.
1107    pub fn evict_stale_plans(&self) {
1108        let now = Self::now_us();
1109        self.plan_cache
1110            .write()
1111            .retain(|_, (_, ts)| now.saturating_sub(*ts) < self.plan_cache_ttl_us);
1112    }
1113
1114    /// Clear the entire plan cache (call after DDL or bulk load).
1115    pub fn invalidate_plan_cache(&self) {
1116        self.plan_cache.write().clear();
1117    }
1118
1119    /// Collect fresh statistics for a table from row data.
1120    ///
1121    /// Pass an iterator of (column_name, value_as_string) pairs per row.
1122    /// This builds column stats with distinct counts and optional histograms.
1123    pub fn collect_stats(
1124        &self,
1125        table_name: &str,
1126        row_count: u64,
1127        size_bytes: u64,
1128        column_values: HashMap<String, Vec<String>>,
1129        indices: Vec<IndexStats>,
1130    ) {
1131        let mut column_stats = HashMap::new();
1132        for (col_name, values) in &column_values {
1133            let distinct: HashSet<&String> = values.iter().collect();
1134            let null_count = values.iter().filter(|v| v.is_empty()).count() as u64;
1135            let avg_length = if values.is_empty() {
1136                0.0
1137            } else {
1138                values.iter().map(|v| v.len()).sum::<usize>() as f64 / values.len() as f64
1139            };
1140
1141            // Build histogram for numeric columns (try parse first 10 values)
1142            let is_numeric = values.iter().take(10).all(|v| v.parse::<f64>().is_ok());
1143            let histogram = if is_numeric && values.len() >= 10 {
1144                let mut nums: Vec<f64> = values.iter().filter_map(|v| v.parse().ok()).collect();
1145                nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1146                let bucket_count = 10.min(nums.len());
1147                let bucket_size = nums.len() / bucket_count;
1148                let mut boundaries = Vec::new();
1149                let mut counts = Vec::new();
1150                for i in 0..bucket_count {
1151                    let end = if i == bucket_count - 1 {
1152                        nums.len()
1153                    } else {
1154                        (i + 1) * bucket_size
1155                    };
1156                    let start = i * bucket_size;
1157                    boundaries.push(nums[end - 1]);
1158                    counts.push((end - start) as u64);
1159                }
1160                Some(Histogram {
1161                    boundaries,
1162                    counts,
1163                    total_rows: nums.len() as u64,
1164                })
1165            } else {
1166                None
1167            };
1168
1169            // Build MCV (top 5 most common values)
1170            let mut freq_map: HashMap<&String, usize> = HashMap::new();
1171            for v in values {
1172                *freq_map.entry(v).or_insert(0) += 1;
1173            }
1174            let total = values.len() as f64;
1175            let mut mcv: Vec<(String, f64)> = freq_map
1176                .iter()
1177                .map(|(k, &v)| ((*k).clone(), v as f64 / total))
1178                .collect();
1179            mcv.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1180            mcv.truncate(5);
1181
1182            column_stats.insert(
1183                col_name.clone(),
1184                ColumnStats {
1185                    name: col_name.clone(),
1186                    distinct_count: distinct.len() as u64,
1187                    null_count,
1188                    min_value: values.iter().min().cloned(),
1189                    max_value: values.iter().max().cloned(),
1190                    avg_length,
1191                    mcv,
1192                    histogram,
1193                },
1194            );
1195        }
1196
1197        self.update_stats(TableStats {
1198            name: table_name.to_string(),
1199            row_count,
1200            size_bytes,
1201            column_stats,
1202            indices,
1203            last_updated: Self::now_us(),
1204        });
1205
1206        // Invalidate cached plans for this table
1207        self.invalidate_plan_cache();
1208    }
1209
1210    /// Check if stats are stale (older than threshold)
1211    pub fn stats_age_us(&self, table: &str) -> Option<u64> {
1212        self.stats_cache
1213            .read()
1214            .get(table)
1215            .map(|s| Self::now_us().saturating_sub(s.last_updated))
1216    }
1217
1218    fn now_us() -> u64 {
1219        SystemTime::now()
1220            .duration_since(UNIX_EPOCH)
1221            .unwrap_or_default()
1222            .as_micros() as u64
1223    }
1224}
1225
1226// ============================================================================
1227// Join Order Optimizer (Dynamic Programming)
1228// ============================================================================
1229
1230/// Join order optimizer using dynamic programming
1231pub struct JoinOrderOptimizer {
1232    /// Table statistics
1233    stats: HashMap<String, TableStats>,
1234    /// Cost model
1235    config: CostModelConfig,
1236}
1237
1238impl JoinOrderOptimizer {
1239    pub fn new(config: CostModelConfig) -> Self {
1240        Self {
1241            stats: HashMap::new(),
1242            config,
1243        }
1244    }
1245
1246    /// Add table statistics
1247    pub fn add_stats(&mut self, stats: TableStats) {
1248        self.stats.insert(stats.name.clone(), stats);
1249    }
1250
1251    /// Find optimal join order using dynamic programming
1252    ///
1253    /// Time: O(2^n × n^2) where n = number of tables
1254    /// Practical for n ≤ 10
1255    pub fn find_optimal_order(
1256        &self,
1257        tables: &[String],
1258        join_conditions: &[(String, String, String, String)], // (table1, col1, table2, col2)
1259    ) -> Vec<(String, String)> {
1260        let n = tables.len();
1261        if n <= 1 {
1262            return vec![];
1263        }
1264
1265        // dp[mask] = (cost, join_order)
1266        let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1267
1268        // Base case: single tables
1269        for (i, _table) in tables.iter().enumerate() {
1270            let mask = 1u32 << i;
1271            dp.insert(mask, (0.0, vec![]));
1272        }
1273
1274        // Build up larger subsets
1275        for size in 2..=n {
1276            for mask in 0..(1u32 << n) {
1277                if mask.count_ones() != size as u32 {
1278                    continue;
1279                }
1280
1281                let mut best_cost = f64::MAX;
1282                let mut best_order = vec![];
1283
1284                // Try all ways to split into two non-empty subsets
1285                for sub in 1..mask {
1286                    if sub & mask != sub || sub == 0 {
1287                        continue;
1288                    }
1289                    let other = mask ^ sub;
1290                    if other == 0 {
1291                        continue;
1292                    }
1293
1294                    // Check if there's a join between sub and other
1295                    if !self.has_join_condition(tables, sub, other, join_conditions) {
1296                        continue;
1297                    }
1298
1299                    if let (Some((cost1, order1)), Some((cost2, order2))) =
1300                        (dp.get(&sub), dp.get(&other))
1301                    {
1302                        let join_cost = self.estimate_join_cost(tables, sub, other);
1303                        let total_cost = cost1 + cost2 + join_cost;
1304
1305                        if total_cost < best_cost {
1306                            best_cost = total_cost;
1307                            best_order = order1.clone();
1308                            best_order.extend(order2.clone());
1309
1310                            // Add the join
1311                            let (t1, t2) =
1312                                self.get_join_tables(tables, sub, other, join_conditions);
1313                            if let Some((t1, t2)) = Some((t1, t2)) {
1314                                best_order.push((t1, t2));
1315                            }
1316                        }
1317                    }
1318                }
1319
1320                if best_cost < f64::MAX {
1321                    dp.insert(mask, (best_cost, best_order));
1322                }
1323            }
1324        }
1325
1326        let full_mask = (1u32 << n) - 1;
1327        dp.get(&full_mask)
1328            .map(|(_, order)| order.clone())
1329            .unwrap_or_default()
1330    }
1331
1332    fn has_join_condition(
1333        &self,
1334        tables: &[String],
1335        mask1: u32,
1336        mask2: u32,
1337        conditions: &[(String, String, String, String)],
1338    ) -> bool {
1339        for (t1, _, t2, _) in conditions {
1340            let idx1 = tables.iter().position(|t| t == t1);
1341            let idx2 = tables.iter().position(|t| t == t2);
1342
1343            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1344                let in_mask1 = (mask1 >> i1) & 1 == 1;
1345                let in_mask2 = (mask2 >> i2) & 1 == 1;
1346
1347                if in_mask1 && in_mask2 {
1348                    return true;
1349                }
1350            }
1351        }
1352        false
1353    }
1354
1355    fn get_join_tables(
1356        &self,
1357        tables: &[String],
1358        mask1: u32,
1359        mask2: u32,
1360        conditions: &[(String, String, String, String)],
1361    ) -> (String, String) {
1362        for (t1, _, t2, _) in conditions {
1363            let idx1 = tables.iter().position(|t| t == t1);
1364            let idx2 = tables.iter().position(|t| t == t2);
1365
1366            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1367                let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1368                let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1369
1370                if t1_in_mask1 && t2_in_mask2 {
1371                    return (t1.clone(), t2.clone());
1372                }
1373            }
1374        }
1375        (String::new(), String::new())
1376    }
1377
1378    fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1379        let rows1 = self.estimate_rows_for_mask(tables, mask1);
1380        let rows2 = self.estimate_rows_for_mask(tables, mask2);
1381
1382        // Hash join cost estimate
1383        // Build cost + probe cost
1384        let build_cost = rows1 as f64 * self.config.c_filter;
1385        let probe_cost = rows2 as f64 * self.config.c_filter;
1386
1387        build_cost + probe_cost
1388    }
1389
1390    fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1391        let mut total = 1u64;
1392
1393        for (i, table) in tables.iter().enumerate() {
1394            if (mask >> i) & 1 == 1 {
1395                let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1396                total = total.saturating_mul(rows);
1397            }
1398        }
1399
1400        // Apply default selectivity for joins
1401        let num_tables = mask.count_ones();
1402        if num_tables > 1 {
1403            total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1404        }
1405
1406        total.max(1)
1407    }
1408}
1409
1410// ============================================================================
1411// Tests
1412// ============================================================================
1413
1414#[cfg(test)]
1415mod tests {
1416    use super::*;
1417
1418    fn create_test_stats() -> TableStats {
1419        let mut column_stats = HashMap::new();
1420        column_stats.insert(
1421            "id".to_string(),
1422            ColumnStats {
1423                name: "id".to_string(),
1424                distinct_count: 100000,
1425                null_count: 0,
1426                min_value: Some("1".to_string()),
1427                max_value: Some("100000".to_string()),
1428                avg_length: 8.0,
1429                mcv: vec![],
1430                histogram: None,
1431            },
1432        );
1433        column_stats.insert(
1434            "score".to_string(),
1435            ColumnStats {
1436                name: "score".to_string(),
1437                distinct_count: 100,
1438                null_count: 1000,
1439                min_value: Some("0".to_string()),
1440                max_value: Some("100".to_string()),
1441                avg_length: 8.0,
1442                mcv: vec![("50".to_string(), 0.05)],
1443                histogram: Some(Histogram {
1444                    boundaries: vec![25.0, 50.0, 75.0, 100.0],
1445                    counts: vec![25000, 25000, 25000, 25000],
1446                    total_rows: 100000,
1447                }),
1448            },
1449        );
1450
1451        TableStats {
1452            name: "users".to_string(),
1453            row_count: 100000,
1454            size_bytes: 10_000_000, // 10 MB
1455            column_stats,
1456            indices: vec![
1457                IndexStats {
1458                    name: "pk_users".to_string(),
1459                    columns: vec!["id".to_string()],
1460                    is_primary: true,
1461                    is_unique: true,
1462                    index_type: IndexType::BTree,
1463                    leaf_pages: 1000,
1464                    height: 3,
1465                    avg_leaf_density: 100.0,
1466                },
1467                IndexStats {
1468                    name: "idx_score".to_string(),
1469                    columns: vec!["score".to_string()],
1470                    is_primary: false,
1471                    is_unique: false,
1472                    index_type: IndexType::BTree,
1473                    leaf_pages: 500,
1474                    height: 2,
1475                    avg_leaf_density: 200.0,
1476                },
1477            ],
1478            last_updated: 0,
1479        }
1480    }
1481
1482    #[test]
1483    fn test_selectivity_estimation() {
1484        let config = CostModelConfig::default();
1485        let optimizer = CostBasedOptimizer::new(config);
1486
1487        let stats = create_test_stats();
1488        optimizer.update_stats(stats.clone());
1489
1490        // Equality predicate
1491        let pred = Predicate::Eq {
1492            column: "id".to_string(),
1493            value: "12345".to_string(),
1494        };
1495        let sel = optimizer.estimate_selectivity(&pred, &stats);
1496        assert!(sel < 0.001); // Should be very selective
1497
1498        // Range predicate with histogram
1499        // Note: For histogram boundaries [25, 50, 75, 100] with equal distribution,
1500        // Gt{75} includes buckets with bucket_max >= 75, which is buckets 2 and 3 (50%)
1501        let pred = Predicate::Gt {
1502            column: "score".to_string(),
1503            value: "75".to_string(),
1504        };
1505        let sel = optimizer.estimate_selectivity(&pred, &stats);
1506        assert!(sel > 0.4 && sel < 0.6); // ~50% from histogram (2 of 4 buckets)
1507    }
1508
1509    #[test]
1510    fn test_access_path_selection() {
1511        let config = CostModelConfig::default();
1512        let optimizer = CostBasedOptimizer::new(config);
1513
1514        let stats = create_test_stats();
1515        optimizer.update_stats(stats);
1516
1517        // High selectivity predicate should use index
1518        let pred = Predicate::Eq {
1519            column: "id".to_string(),
1520            value: "12345".to_string(),
1521        };
1522        let plan = optimizer.optimize(
1523            "users",
1524            vec!["id".to_string(), "score".to_string()],
1525            Some(pred),
1526            vec![],
1527            None,
1528        );
1529
1530        match plan {
1531            PhysicalPlan::IndexSeek { index, .. } => {
1532                assert_eq!(index, "pk_users");
1533            }
1534            _ => panic!("Expected IndexSeek for equality on primary key"),
1535        }
1536    }
1537
1538    #[test]
1539    fn test_token_budget_limit() {
1540        let config = CostModelConfig::default();
1541        let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);
1542
1543        // With 2048 token budget and 25 tokens/row:
1544        // max_rows = (2048 - 50) / 25 = ~80
1545        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1546
1547        match plan {
1548            PhysicalPlan::Limit { limit, .. } => {
1549                assert!(limit <= 80);
1550            }
1551            _ => panic!("Expected Limit to be injected"),
1552        }
1553    }
1554
1555    #[test]
1556    fn test_explain_output() {
1557        let config = CostModelConfig::default();
1558        let optimizer = CostBasedOptimizer::new(config);
1559
1560        let stats = create_test_stats();
1561        optimizer.update_stats(stats);
1562
1563        let plan = optimizer.optimize(
1564            "users",
1565            vec!["id".to_string(), "score".to_string()],
1566            Some(Predicate::Gt {
1567                column: "score".to_string(),
1568                value: "80".to_string(),
1569            }),
1570            vec![("score".to_string(), SortDirection::Descending)],
1571            Some(10),
1572        );
1573
1574        let explain = optimizer.explain(&plan);
1575        assert!(explain.contains("Limit"));
1576        assert!(explain.contains("Sort"));
1577    }
1578
1579    // ================================================================
1580    // Production-grade tests
1581    // ================================================================
1582
1583    #[test]
1584    fn test_token_budget_underflow_safety() {
1585        // Ensure small budget doesn't panic (saturating_sub)
1586        let config = CostModelConfig::default();
1587        let optimizer = CostBasedOptimizer::new(config).with_token_budget(10, 25.0);
1588
1589        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1590        match plan {
1591            PhysicalPlan::Limit { limit, .. } => {
1592                assert!(limit >= 1, "Must return at least 1 row");
1593            }
1594            _ => panic!("Expected Limit"),
1595        }
1596    }
1597
1598    #[test]
1599    fn test_index_seek_derives_key_range() {
1600        let config = CostModelConfig::default();
1601        let optimizer = CostBasedOptimizer::new(config);
1602        optimizer.update_stats(create_test_stats());
1603
1604        let plan = optimizer.optimize(
1605            "users",
1606            vec!["id".to_string()],
1607            Some(Predicate::Eq {
1608                column: "id".to_string(),
1609                value: "42".to_string(),
1610            }),
1611            vec![],
1612            None,
1613        );
1614
1615        match plan {
1616            PhysicalPlan::IndexSeek { key_range, .. } => {
1617                assert!(
1618                    key_range.start.is_some(),
1619                    "KeyRange must derive from Eq predicate"
1620                );
1621                assert_eq!(
1622                    key_range.start, key_range.end,
1623                    "Eq predicate → point key range"
1624                );
1625            }
1626            _ => panic!("Expected IndexSeek"),
1627        }
1628    }
1629
1630    #[test]
1631    fn test_range_predicate_key_range() {
1632        let config = CostModelConfig::default();
1633        let optimizer = CostBasedOptimizer::new(config);
1634        optimizer.update_stats(create_test_stats());
1635
1636        let plan = optimizer.optimize(
1637            "users",
1638            vec!["score".to_string()],
1639            Some(Predicate::Between {
1640                column: "score".to_string(),
1641                min: "10".to_string(),
1642                max: "90".to_string(),
1643            }),
1644            vec![],
1645            None,
1646        );
1647
1648        match plan {
1649            PhysicalPlan::IndexSeek { key_range, .. } => {
1650                assert!(key_range.start.is_some());
1651                assert!(key_range.end.is_some());
1652                assert!(key_range.start_inclusive);
1653                assert!(key_range.end_inclusive);
1654            }
1655            _ => {} // May choose scan if cheaper — that's OK
1656        }
1657    }
1658
1659    #[test]
1660    fn test_projection_pushdown_proportional_reduction() {
1661        let config = CostModelConfig::default();
1662        let optimizer = CostBasedOptimizer::new(config);
1663        optimizer.update_stats(create_test_stats());
1664
1665        // Select 1 of 2 columns → ~50% cost reduction on table scan
1666        let plan_all = optimizer.optimize(
1667            "users",
1668            vec!["id".to_string(), "score".to_string()],
1669            None,
1670            vec![],
1671            Some(100),
1672        );
1673        let plan_single =
1674            optimizer.optimize("users", vec!["id".to_string()], None, vec![], Some(100));
1675
1676        let cost_all = optimizer.get_plan_cost(&plan_all);
1677        let cost_single = optimizer.get_plan_cost(&plan_single);
1678        // Single column should cost less than or equal to all columns
1679        assert!(
1680            cost_single <= cost_all,
1681            "Projection should reduce cost: {} vs {}",
1682            cost_single,
1683            cost_all
1684        );
1685    }
1686
1687    #[test]
1688    fn test_collect_stats_builds_histogram() {
1689        let config = CostModelConfig::default();
1690        let optimizer = CostBasedOptimizer::new(config);
1691
1692        let mut column_values = HashMap::new();
1693        let scores: Vec<String> = (0..100).map(|i| i.to_string()).collect();
1694        column_values.insert("score".to_string(), scores);
1695
1696        optimizer.collect_stats("test_table", 100, 10000, column_values, vec![]);
1697
1698        let stats = optimizer.get_stats("test_table").unwrap();
1699        assert_eq!(stats.row_count, 100);
1700        let score_stats = stats.column_stats.get("score").unwrap();
1701        assert_eq!(score_stats.distinct_count, 100);
1702        assert!(
1703            score_stats.histogram.is_some(),
1704            "Numeric column should get histogram"
1705        );
1706        assert!(!score_stats.mcv.is_empty(), "Should build MCV list");
1707    }
1708
1709    #[test]
1710    fn test_plan_cache_invalidation() {
1711        let config = CostModelConfig::default();
1712        let optimizer = CostBasedOptimizer::new(config);
1713
1714        // Collecting stats should invalidate cache
1715        let mut col = HashMap::new();
1716        col.insert("x".to_string(), vec!["1".to_string()]);
1717        optimizer.collect_stats("t", 1, 100, col.clone(), vec![]);
1718
1719        // Cache should be empty after stats collection
1720        assert!(optimizer.plan_cache.read().is_empty());
1721    }
1722
1723    #[test]
1724    fn test_stats_age_tracking() {
1725        let config = CostModelConfig::default();
1726        let optimizer = CostBasedOptimizer::new(config);
1727
1728        assert!(optimizer.stats_age_us("unknown").is_none());
1729
1730        let mut col = HashMap::new();
1731        col.insert("x".to_string(), vec!["1".to_string()]);
1732        optimizer.collect_stats("t", 1, 100, col, vec![]);
1733
1734        let age = optimizer.stats_age_us("t").unwrap();
1735        assert!(age < 1_000_000, "Stats should be fresh (< 1 second old)");
1736    }
1737
1738    #[test]
1739    fn test_scan_cost_reads_all_blocks() {
1740        // Scan cost must NOT multiply by selectivity — scans read everything
1741        let config = CostModelConfig::default();
1742        let optimizer = CostBasedOptimizer::new(config.clone());
1743        let no_pred = optimizer.estimate_scan_cost(1000, 4096 * 10, None);
1744        let with_pred = optimizer.estimate_scan_cost(
1745            1000,
1746            4096 * 10,
1747            Some(&Predicate::Eq {
1748                column: "x".to_string(),
1749                value: "1".to_string(),
1750            }),
1751        );
1752        // Scan cost should be the same regardless of predicate
1753        // (scan reads all blocks; predicate doesn't reduce I/O)
1754        assert!(
1755            (no_pred - with_pred).abs() < 0.001,
1756            "Scan cost should not depend on predicate: {} vs {}",
1757            no_pred,
1758            with_pred
1759        );
1760    }
1761
1762    #[test]
1763    fn test_index_wins_over_scan_for_point_lookup() {
1764        let config = CostModelConfig::default();
1765        let optimizer = CostBasedOptimizer::new(config);
1766        optimizer.update_stats(create_test_stats());
1767
1768        let scan_cost = optimizer.estimate_scan_cost(100000, 10_000_000, None);
1769
1770        // Index cost for a point lookup should be orders of magnitude cheaper
1771        let pk_index = &create_test_stats().indices[0]; // pk_users
1772        let index_cost = optimizer.estimate_index_cost(pk_index, 100000, 0.00001);
1773
1774        assert!(
1775            index_cost < scan_cost * 0.1,
1776            "Index point lookup ({:.2}) should be <10% of scan cost ({:.2})",
1777            index_cost,
1778            scan_cost
1779        );
1780    }
1781
1782    #[test]
1783    fn test_no_stats_defaults_to_scan() {
1784        let config = CostModelConfig::default();
1785        let optimizer = CostBasedOptimizer::new(config);
1786        // No stats loaded — optimizer should still work with defaults
1787        let plan = optimizer.optimize(
1788            "unknown_table",
1789            vec!["col1".to_string()],
1790            Some(Predicate::Eq {
1791                column: "col1".to_string(),
1792                value: "x".to_string(),
1793            }),
1794            vec![],
1795            None,
1796        );
1797        // Should produce a valid plan (TableScan with default estimates)
1798        match plan {
1799            PhysicalPlan::TableScan { estimated_rows, .. } => {
1800                assert!(estimated_rows > 0, "Default row estimate must be positive");
1801            }
1802            PhysicalPlan::IndexSeek { .. } => {} // also fine with no stats
1803            _ => panic!("Expected TableScan or IndexSeek for unknown table"),
1804        }
1805    }
1806
1807    #[test]
1808    fn test_compound_predicate_selectivity() {
1809        let stats = create_test_stats();
1810        let config = CostModelConfig::default();
1811        let optimizer = CostBasedOptimizer::new(config);
1812
1813        // AND: independent → multiply
1814        let and_pred = Predicate::And(
1815            Box::new(Predicate::Eq {
1816                column: "id".to_string(),
1817                value: "1".to_string(),
1818            }),
1819            Box::new(Predicate::IsNotNull {
1820                column: "score".to_string(),
1821            }),
1822        );
1823        let sel = optimizer.estimate_selectivity(&and_pred, &stats);
1824        let eq_sel = optimizer.estimate_selectivity(
1825            &Predicate::Eq {
1826                column: "id".to_string(),
1827                value: "1".to_string(),
1828            },
1829            &stats,
1830        );
1831        assert!(sel < eq_sel, "AND must be more selective than either child");
1832
1833        // OR: P(A∪B) = P(A)+P(B)-P(A∩B)
1834        let or_pred = Predicate::Or(
1835            Box::new(Predicate::Eq {
1836                column: "id".to_string(),
1837                value: "1".to_string(),
1838            }),
1839            Box::new(Predicate::Eq {
1840                column: "id".to_string(),
1841                value: "2".to_string(),
1842            }),
1843        );
1844        let sel = optimizer.estimate_selectivity(&or_pred, &stats);
1845        assert!(sel > eq_sel, "OR must be less selective than either child");
1846        assert!(sel <= 1.0, "Selectivity must be <= 1.0");
1847    }
1848
1849    #[test]
1850    fn test_join_order_optimizer() {
1851        let mut join_opt = JoinOrderOptimizer::new(CostModelConfig::default());
1852        join_opt.add_stats(TableStats {
1853            name: "orders".to_string(),
1854            row_count: 1000000,
1855            size_bytes: 100_000_000,
1856            column_stats: HashMap::new(),
1857            indices: vec![],
1858            last_updated: 0,
1859        });
1860        join_opt.add_stats(TableStats {
1861            name: "users".to_string(),
1862            row_count: 10000,
1863            size_bytes: 1_000_000,
1864            column_stats: HashMap::new(),
1865            indices: vec![],
1866            last_updated: 0,
1867        });
1868
1869        let order = join_opt.find_optimal_order(
1870            &["orders".to_string(), "users".to_string()],
1871            &[(
1872                "orders".to_string(),
1873                "user_id".to_string(),
1874                "users".to_string(),
1875                "id".to_string(),
1876            )],
1877        );
1878        assert!(!order.is_empty(), "Should find a join order");
1879    }
1880}