sochdb_query/
cost_optimizer.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Cost-Based Query Optimizer with Cardinality Estimation (Task 6)
16//!
17//! Provides cost-based query optimization for SOCH-QL with:
18//! - Cardinality estimation using sketches (HyperLogLog, CountMin)
19//! - Index selection: compare cost(table_scan) vs cost(index_seek)
20//! - Column projection pushdown to LSCS layer
21//! - Token-budget-aware planning
22//!
23//! ## Cost Model
24//!
25//! cost(plan) = I/O_cost + CPU_cost + memory_cost
26//!
27//! I/O_cost = blocks_read × C_seq + seeks × C_random
28//! Where:
29//!   C_seq = 0.1 ms/block (sequential read)
30//!   C_random = 5 ms/seek (random seek)
31//!
32//! CPU_cost = rows_processed × C_filter + sorts × N × log(N) × C_compare
33//!
34//! ## Selectivity Estimation
35//!
36//! Uses CountMinSketch for predicate selectivity and HyperLogLog for distinct counts.
37//!
38//! ## Token Budget Planning
39//!
40//! Given max_tokens, estimates result size and injects LIMIT clause:
41//!   max_rows = (max_tokens - header_tokens) / tokens_per_row
42
43use parking_lot::RwLock;
44use std::collections::{HashMap, HashSet};
45use std::sync::Arc;
46
47// ============================================================================
48// Cost Model Constants
49// ============================================================================
50
/// Cost model configuration with empirically-derived constants.
///
/// Time-based costs are expressed in milliseconds so that I/O and CPU terms
/// can be summed directly into a single plan cost.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Sequential I/O cost per block (ms)
    pub c_seq: f64,
    /// Random I/O cost per seek (ms)
    pub c_random: f64,
    /// CPU cost per row filter (ms)
    pub c_filter: f64,
    /// CPU cost per comparison during sort (ms)
    pub c_compare: f64,
    /// Block size in bytes
    pub block_size: usize,
    /// B-tree fanout for index cost estimation
    pub btree_fanout: usize,
    /// Memory bandwidth (bytes/ms)
    pub memory_bandwidth: f64,
}

impl Default for CostModelConfig {
    fn default() -> Self {
        Self {
            c_seq: 0.1,        // 0.1 ms per block sequential
            c_random: 5.0,     // 5 ms per random seek
            c_filter: 0.001,   // 0.001 ms per row filter
            c_compare: 0.0001, // 0.0001 ms per comparison
            block_size: 4096,  // 4 KB blocks
            btree_fanout: 100, // 100 entries per B-tree node
            // 10 GB/s = 10^10 bytes / 1000 ms = 10^7 bytes/ms. The previous
            // value (10_000.0) was off by 1000x and really modelled 10 MB/s.
            memory_bandwidth: 10_000_000.0,
        }
    }
}
83
84// ============================================================================
85// Statistics for Cardinality Estimation
86// ============================================================================
87
/// Table statistics for cost estimation.
///
/// The optimizer caches one of these per table (keyed by `name`) and reads it
/// when costing access paths and estimating predicate selectivity.
#[derive(Debug, Clone)]
pub struct TableStats {
    /// Table name (also the key under which the optimizer caches these stats)
    pub name: String,
    /// Total row count; also the denominator for per-column null fractions
    pub row_count: u64,
    /// Total size in bytes; divided by the configured block size to estimate
    /// how many blocks a full scan reads
    pub size_bytes: u64,
    /// Column statistics, keyed by column name
    pub column_stats: HashMap<String, ColumnStats>,
    /// Available indices (candidates for an index seek)
    pub indices: Vec<IndexStats>,
    /// Last update timestamp
    /// NOTE(review): unit/epoch is not established here — presumably Unix
    /// seconds or milliseconds; confirm with the statistics collector.
    pub last_updated: u64,
}
104
/// Per-column statistics used for selectivity estimation.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name
    pub name: String,
    /// Distinct value count (from HyperLogLog); equality and `IN` estimates
    /// assume rows are spread uniformly over this many distinct values
    pub distinct_count: u64,
    /// Null count; divided by the table's `row_count` for `IS [NOT] NULL`
    pub null_count: u64,
    /// Minimum value (if orderable)
    pub min_value: Option<String>,
    /// Maximum value (if orderable)
    pub max_value: Option<String>,
    /// Average length in bytes (for variable-length types)
    pub avg_length: f64,
    /// Most common values with frequencies. A listed frequency is used
    /// directly as the selectivity of `column = value`, so it is expected to
    /// be a fraction of table rows in `[0.0, 1.0]`.
    pub mcv: Vec<(String, f64)>,
    /// Histogram buckets for range queries (probed with predicate literals
    /// parsed as `f64`)
    pub histogram: Option<Histogram>,
}
125
/// Histogram for range selectivity estimation.
///
/// Bucket `i` spans `[boundaries[i - 1], boundaries[i]]`. The first bucket is
/// open towards negative infinity and the bucket at index `boundaries.len()`
/// is open towards positive infinity, so `n` interior boundaries normally
/// describe `n + 1` buckets.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Bucket boundaries (expected to be sorted ascending)
    pub boundaries: Vec<f64>,
    /// Row count per bucket
    pub counts: Vec<u64>,
    /// Total rows in histogram
    pub total_rows: u64,
}

impl Histogram {
    /// Estimate selectivity for a range predicate.
    ///
    /// `min` / `max` are the lower and upper bounds of the range (compared
    /// inclusively against bucket edges); `None` means unbounded on that side.
    /// Every bucket that overlaps the range contributes its full row count
    /// (no intra-bucket interpolation), so the estimate is coarse at bucket
    /// granularity.
    ///
    /// Returns a fraction in `[0.0, 1.0]`, or the neutral default `0.5` when
    /// the histogram is empty (`total_rows == 0`). A malformed histogram with
    /// more buckets than `boundaries.len() + 1` does not panic: a missing
    /// boundary is treated as unbounded.
    pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
        if self.total_rows == 0 {
            return 0.5; // Default
        }

        let selected_rows = self
            .counts
            .iter()
            .enumerate()
            .filter(|&(i, _)| {
                // Out-of-range boundary lookups mean "unbounded on this side".
                let bucket_min = i
                    .checked_sub(1)
                    .and_then(|j| self.boundaries.get(j))
                    .copied()
                    .unwrap_or(f64::NEG_INFINITY);
                let bucket_max = self.boundaries.get(i).copied().unwrap_or(f64::INFINITY);

                min.is_none_or(|lo| bucket_max >= lo) && max.is_none_or(|hi| bucket_min <= hi)
            })
            .map(|(_, &count)| count)
            // Saturate rather than overflow on absurd counts.
            .fold(0u64, u64::saturating_add);

        // Counts that exceed `total_rows` must not yield a selectivity > 1.
        (selected_rows as f64 / self.total_rows as f64).min(1.0)
    }
}
173
/// Index statistics used to cost an index seek.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Index name
    pub name: String,
    /// Indexed columns, in key order; only the first (leading) column is
    /// matched against predicate columns when deciding if the index is usable
    pub columns: Vec<String>,
    /// Is primary key. The cost model treats a primary index as clustered,
    /// i.e. no separate row-fetch cost is charged.
    pub is_primary: bool,
    /// Is unique
    pub is_unique: bool,
    /// Index type
    pub index_type: IndexType,
    /// Number of leaf pages
    pub leaf_pages: u64,
    /// Tree height (for B-tree); each level is charged one random seek
    pub height: u32,
    /// Average entries per leaf page; matching rows are divided by this to
    /// estimate leaf pages scanned (0.0 makes that estimate unbounded when any
    /// rows match)
    pub avg_leaf_density: f64,
}
194
/// Index types.
///
/// NOTE(review): the index cost estimate in this module does not branch on
/// the index type; every type is costed with the same height / leaf-page
/// formula.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// B-tree index
    BTree,
    /// Hash index
    Hash,
    /// LSM-tree index
    LSM,
    /// Learned index
    Learned,
    /// Vector index
    Vector,
    /// Bloom-filter index
    Bloom,
}
205
206// ============================================================================
207// Query Predicates and Operations
208// ============================================================================
209
/// Query predicate for cost estimation.
///
/// Literal values are carried as strings; range selectivity estimation parses
/// them as `f64` when probing a histogram.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// Equality: column = value
    Eq { column: String, value: String },
    /// Inequality: column != value
    Ne { column: String, value: String },
    /// Less than: column < value
    Lt { column: String, value: String },
    /// Less than or equal: column <= value
    Le { column: String, value: String },
    /// Greater than: column > value
    Gt { column: String, value: String },
    /// Greater than or equal: column >= value
    Ge { column: String, value: String },
    /// Between: column BETWEEN min AND max
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// In list: column IN (v1, v2, ...)
    In { column: String, values: Vec<String> },
    /// Like: column LIKE pattern
    Like { column: String, pattern: String },
    /// Is null: column IS NULL
    IsNull { column: String },
    /// Is not null: column IS NOT NULL
    IsNotNull { column: String },
    /// And: pred1 AND pred2
    And(Box<Predicate>, Box<Predicate>),
    /// Or: pred1 OR pred2
    Or(Box<Predicate>, Box<Predicate>),
    /// Not: NOT pred
    Not(Box<Predicate>),
}

impl Predicate {
    /// Returns the distinct set of column names referenced anywhere in this
    /// predicate tree, including nested `AND` / `OR` / `NOT` operands.
    pub fn referenced_columns(&self) -> HashSet<String> {
        let mut found = HashSet::new();
        self.collect_columns(&mut found);
        found
    }

    /// Walks the predicate tree with an explicit work stack and inserts every
    /// leaf comparison's column into `cols`.
    fn collect_columns(&self, cols: &mut HashSet<String>) {
        let mut pending: Vec<&Self> = vec![self];

        while let Some(node) = pending.pop() {
            match node {
                Self::And(lhs, rhs) | Self::Or(lhs, rhs) => {
                    pending.push(&**rhs);
                    pending.push(&**lhs);
                }
                Self::Not(negated) => pending.push(&**negated),
                Self::Eq { column, .. }
                | Self::Ne { column, .. }
                | Self::Lt { column, .. }
                | Self::Le { column, .. }
                | Self::Gt { column, .. }
                | Self::Ge { column, .. }
                | Self::Between { column, .. }
                | Self::In { column, .. }
                | Self::Like { column, .. }
                | Self::IsNull { column }
                | Self::IsNotNull { column } => {
                    cols.insert(column.clone());
                }
            }
        }
    }
}
278
279// ============================================================================
280// Physical Plan Operators
281// ============================================================================
282
/// Physical query plan node.
///
/// `estimated_cost` is in milliseconds (the unit used by [`CostModelConfig`]).
/// For nodes with inputs it is the node's own incremental cost:
/// `CostBasedOptimizer::get_plan_cost` adds it to the cost of the inputs to
/// obtain a subtree total. `estimated_rows` is the node's output-cardinality
/// estimate; nodes without it (`Project`, `Sort`, `Limit`) derive their row
/// count from their input or their limit.
#[derive(Debug, Clone)]
pub enum PhysicalPlan {
    /// Table scan (full or partial)
    TableScan {
        table: String,
        columns: Vec<String>,
        predicate: Option<Box<Predicate>>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Index seek over `key_range` of the named `index`
    IndexSeek {
        table: String,
        index: String,
        columns: Vec<String>,
        key_range: KeyRange,
        predicate: Option<Box<Predicate>>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Filter operator
    Filter {
        input: Box<PhysicalPlan>,
        predicate: Predicate,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Project operator (column subset)
    Project {
        input: Box<PhysicalPlan>,
        columns: Vec<String>,
        estimated_cost: f64,
    },
    /// Sort operator; `order_by` lists (column, direction) sort keys
    Sort {
        input: Box<PhysicalPlan>,
        order_by: Vec<(String, SortDirection)>,
        estimated_cost: f64,
    },
    /// Limit operator (`LIMIT limit OFFSET offset`)
    Limit {
        input: Box<PhysicalPlan>,
        limit: u64,
        offset: u64,
        estimated_cost: f64,
    },
    /// Nested loop join
    NestedLoopJoin {
        outer: Box<PhysicalPlan>,
        inner: Box<PhysicalPlan>,
        condition: Predicate,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Hash join on `build_keys` = `probe_keys`
    HashJoin {
        build: Box<PhysicalPlan>,
        probe: Box<PhysicalPlan>,
        build_keys: Vec<String>,
        probe_keys: Vec<String>,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Merge join on `left_keys` = `right_keys`
    MergeJoin {
        left: Box<PhysicalPlan>,
        right: Box<PhysicalPlan>,
        left_keys: Vec<String>,
        right_keys: Vec<String>,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Aggregate operator, grouping by `group_by` columns
    Aggregate {
        input: Box<PhysicalPlan>,
        group_by: Vec<String>,
        aggregates: Vec<AggregateExpr>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
}
368
/// Key range for index seeks.
///
/// `None` for `start` / `end` leaves that side of the range open; the
/// `*_inclusive` flags state whether the respective bound key itself is part
/// of the range.
#[derive(Debug, Clone)]
pub struct KeyRange {
    pub start: Option<Vec<u8>>,
    pub end: Option<Vec<u8>>,
    pub start_inclusive: bool,
    pub end_inclusive: bool,
}

impl KeyRange {
    /// A range with no bounds on either side (both flags inclusive).
    pub fn all() -> Self {
        Self::range(None, None, true)
    }

    /// A range that matches exactly `key` (both bounds inclusive).
    pub fn point(key: Vec<u8>) -> Self {
        let upper = key.clone();
        Self::range(Some(key), Some(upper), true)
    }

    /// A range from `start` to `end`, using the same `inclusive` flag for
    /// both ends.
    pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
        Self {
            start_inclusive: inclusive,
            end_inclusive: inclusive,
            start,
            end,
        }
    }
}
406
/// Sort direction for an `ORDER BY` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Smallest value first
    Ascending,
    /// Largest value first
    Descending,
}
413
/// Join type carried by the join plan nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join
    Inner,
    /// Left outer join
    Left,
    /// Right outer join
    Right,
    /// Full outer join
    Full,
    /// Cross join (Cartesian product)
    Cross,
}
423
/// Aggregate expression, e.g. `SUM(col) AS alias`.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// Aggregate function to apply
    pub function: AggregateFunction,
    /// Input column; presumably `None` for `COUNT(*)` — confirm with the planner
    pub column: Option<String>,
    /// Output name of the aggregate
    pub alias: String,
}
431
/// Aggregate functions supported by the `Aggregate` plan node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// COUNT
    Count,
    /// SUM
    Sum,
    /// AVG
    Avg,
    /// MIN
    Min,
    /// MAX
    Max,
    /// COUNT(DISTINCT ...)
    CountDistinct,
}
442
443// ============================================================================
444// Cost-Based Query Optimizer
445// ============================================================================
446
/// Cost-based query optimizer.
///
/// Holds the cost-model constants, a shared per-table statistics cache, and
/// optional token-budget settings used to cap the number of result rows.
pub struct CostBasedOptimizer {
    /// Cost model configuration
    config: CostModelConfig,
    /// Table statistics cache, keyed by table name. The `RwLock` lets
    /// statistics be updated through `&self`.
    stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
    /// Token budget for result limiting (`None` = no token-based limit)
    token_budget: Option<u64>,
    /// Estimated tokens per row, used to convert the token budget into a row cap
    tokens_per_row: f64,
}
458
459impl CostBasedOptimizer {
460    pub fn new(config: CostModelConfig) -> Self {
461        Self {
462            config,
463            stats_cache: Arc::new(RwLock::new(HashMap::new())),
464            token_budget: None,
465            tokens_per_row: 25.0, // Default estimate
466        }
467    }
468
469    /// Set token budget for result limiting
470    pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
471        self.token_budget = Some(budget);
472        self.tokens_per_row = tokens_per_row;
473        self
474    }
475
476    /// Update table statistics
477    pub fn update_stats(&self, stats: TableStats) {
478        self.stats_cache.write().insert(stats.name.clone(), stats);
479    }
480
481    /// Get table statistics
482    pub fn get_stats(&self, table: &str) -> Option<TableStats> {
483        self.stats_cache.read().get(table).cloned()
484    }
485
486    /// Optimize a SELECT query
487    pub fn optimize(
488        &self,
489        table: &str,
490        columns: Vec<String>,
491        predicate: Option<Predicate>,
492        order_by: Vec<(String, SortDirection)>,
493        limit: Option<u64>,
494    ) -> PhysicalPlan {
495        let stats = self.get_stats(table);
496
497        // Calculate token-aware limit
498        let effective_limit = self.calculate_token_limit(limit);
499
500        // Get best access path (scan vs index)
501        let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
502
503        // Apply column projection pushdown
504        plan = self.apply_projection_pushdown(plan, columns.clone());
505
506        // Apply sorting if needed
507        if !order_by.is_empty() {
508            plan = self.add_sort(plan, order_by, &stats);
509        }
510
511        // Apply limit
512        if let Some(lim) = effective_limit {
513            plan = PhysicalPlan::Limit {
514                estimated_cost: 0.0,
515                input: Box::new(plan),
516                limit: lim,
517                offset: 0,
518            };
519        }
520
521        plan
522    }
523
524    /// Calculate token-aware limit
525    fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
526        match (self.token_budget, user_limit) {
527            (Some(budget), Some(limit)) => {
528                let header_tokens = 50u64;
529                let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
530                Some(limit.min(max_rows))
531            }
532            (Some(budget), None) => {
533                let header_tokens = 50u64;
534                let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
535                Some(max_rows)
536            }
537            (None, limit) => limit,
538        }
539    }
540
541    /// Choose best access path (table scan vs index seek)
542    fn choose_access_path(
543        &self,
544        table: &str,
545        columns: &[String],
546        predicate: Option<&Predicate>,
547        stats: &Option<TableStats>,
548    ) -> PhysicalPlan {
549        let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
550        let size_bytes = stats
551            .as_ref()
552            .map(|s| s.size_bytes)
553            .unwrap_or(row_count * 100);
554
555        // Calculate table scan cost
556        let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
557
558        // Try to find a suitable index
559        let mut best_index_cost = f64::MAX;
560        let mut best_index: Option<&IndexStats> = None;
561
562        if let Some(table_stats) = stats.as_ref()
563            && let Some(pred) = predicate
564        {
565            let pred_columns = pred.referenced_columns();
566
567            for index in &table_stats.indices {
568                if self.index_covers_predicate(index, &pred_columns) {
569                    let selectivity = self.estimate_selectivity(pred, table_stats);
570                    let index_cost = self.estimate_index_cost(index, row_count, selectivity);
571
572                    if index_cost < best_index_cost {
573                        best_index_cost = index_cost;
574                        best_index = Some(index);
575                    }
576                }
577            }
578        }
579
580        // Choose cheaper option
581        if best_index_cost < scan_cost {
582            let index = best_index.unwrap();
583            let selectivity = predicate
584                .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
585                .unwrap_or(1.0);
586
587            PhysicalPlan::IndexSeek {
588                table: table.to_string(),
589                index: index.name.clone(),
590                columns: columns.to_vec(),
591                key_range: KeyRange::all(), // Simplified
592                predicate: predicate.map(|p| Box::new(p.clone())),
593                estimated_rows: (row_count as f64 * selectivity) as u64,
594                estimated_cost: best_index_cost,
595            }
596        } else {
597            PhysicalPlan::TableScan {
598                table: table.to_string(),
599                columns: columns.to_vec(),
600                predicate: predicate.map(|p| Box::new(p.clone())),
601                estimated_rows: row_count,
602                estimated_cost: scan_cost,
603            }
604        }
605    }
606
607    /// Check if index covers predicate columns
608    fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
609        // Index is useful if it covers at least the first column of the predicate
610        if let Some(first_col) = index.columns.first() {
611            pred_columns.contains(first_col)
612        } else {
613            false
614        }
615    }
616
617    /// Estimate table scan cost
618    fn estimate_scan_cost(
619        &self,
620        row_count: u64,
621        size_bytes: u64,
622        predicate: Option<&Predicate>,
623    ) -> f64 {
624        let blocks = (size_bytes as f64 / self.config.block_size as f64).ceil() as u64;
625
626        // I/O cost: sequential read
627        let io_cost = blocks as f64 * self.config.c_seq;
628
629        // CPU cost: filter all rows
630        let selectivity = predicate.map(|_| 0.1).unwrap_or(1.0);
631        let cpu_cost = row_count as f64 * self.config.c_filter * selectivity;
632
633        io_cost + cpu_cost
634    }
635
636    /// Estimate index seek cost
637    ///
638    /// Index cost = tree_traversal + leaf_scan + row_fetch
639    fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
640        // Tree traversal cost (random I/O for each level)
641        let tree_cost = index.height as f64 * self.config.c_random;
642
643        // Leaf scan cost (sequential for matching range)
644        let matching_rows = (total_rows as f64 * selectivity) as u64;
645        let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
646        let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
647
648        // Row fetch cost (random if not clustered)
649        let fetch_cost = if index.is_primary {
650            0.0 // Clustered index, no extra fetch
651        } else {
652            matching_rows.min(1000) as f64 * self.config.c_random * 0.1 // Batch optimization
653        };
654
655        tree_cost + leaf_cost + fetch_cost
656    }
657
658    /// Estimate predicate selectivity
659    #[allow(clippy::only_used_in_recursion)]
660    fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
661        match predicate {
662            Predicate::Eq { column, value } => {
663                if let Some(col_stats) = stats.column_stats.get(column) {
664                    // Check MCV first
665                    for (mcv_val, freq) in &col_stats.mcv {
666                        if mcv_val == value {
667                            return *freq;
668                        }
669                    }
670                    // Otherwise use uniform distribution
671                    1.0 / col_stats.distinct_count.max(1) as f64
672                } else {
673                    0.1 // Default 10%
674                }
675            }
676            Predicate::Ne { .. } => 0.9, // 90% pass
677            Predicate::Lt { column, value }
678            | Predicate::Le { column, value }
679            | Predicate::Gt { column, value }
680            | Predicate::Ge { column, value } => {
681                if let Some(col_stats) = stats.column_stats.get(column) {
682                    if let Some(ref hist) = col_stats.histogram {
683                        let val: f64 = value.parse().unwrap_or(0.0);
684                        match predicate {
685                            Predicate::Lt { .. } | Predicate::Le { .. } => {
686                                hist.estimate_range_selectivity(None, Some(val))
687                            }
688                            _ => hist.estimate_range_selectivity(Some(val), None),
689                        }
690                    } else {
691                        0.25 // Default 25%
692                    }
693                } else {
694                    0.25
695                }
696            }
697            Predicate::Between { column, min, max } => {
698                if let Some(col_stats) = stats.column_stats.get(column) {
699                    if let Some(ref hist) = col_stats.histogram {
700                        let min_val: f64 = min.parse().unwrap_or(0.0);
701                        let max_val: f64 = max.parse().unwrap_or(f64::MAX);
702                        hist.estimate_range_selectivity(Some(min_val), Some(max_val))
703                    } else {
704                        0.2
705                    }
706                } else {
707                    0.2
708                }
709            }
710            Predicate::In { column, values } => {
711                if let Some(col_stats) = stats.column_stats.get(column) {
712                    (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
713                } else {
714                    (values.len() as f64 * 0.1).min(0.5)
715                }
716            }
717            Predicate::Like { .. } => 0.15, // Default 15%
718            Predicate::IsNull { column } => {
719                if let Some(col_stats) = stats.column_stats.get(column) {
720                    col_stats.null_count as f64 / stats.row_count.max(1) as f64
721                } else {
722                    0.01
723                }
724            }
725            Predicate::IsNotNull { column } => {
726                if let Some(col_stats) = stats.column_stats.get(column) {
727                    1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
728                } else {
729                    0.99
730                }
731            }
732            Predicate::And(left, right) => {
733                // Assume independence
734                self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
735            }
736            Predicate::Or(left, right) => {
737                let s1 = self.estimate_selectivity(left, stats);
738                let s2 = self.estimate_selectivity(right, stats);
739                // P(A or B) = P(A) + P(B) - P(A and B)
740                (s1 + s2 - s1 * s2).min(1.0)
741            }
742            Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
743        }
744    }
745
746    /// Apply column projection pushdown
747    fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
748        // If plan already has projection, merge; otherwise add
749        match plan {
750            PhysicalPlan::TableScan {
751                table,
752                predicate,
753                estimated_rows,
754                estimated_cost,
755                ..
756            } => {
757                PhysicalPlan::TableScan {
758                    table,
759                    columns, // Pushed down columns
760                    predicate,
761                    estimated_rows,
762                    estimated_cost: estimated_cost * 0.2, // Reduce cost estimate
763                }
764            }
765            PhysicalPlan::IndexSeek {
766                table,
767                index,
768                key_range,
769                predicate,
770                estimated_rows,
771                estimated_cost,
772                ..
773            } => {
774                PhysicalPlan::IndexSeek {
775                    table,
776                    index,
777                    columns, // Pushed down columns
778                    key_range,
779                    predicate,
780                    estimated_rows,
781                    estimated_cost,
782                }
783            }
784            other => PhysicalPlan::Project {
785                input: Box::new(other),
786                columns,
787                estimated_cost: 0.0,
788            },
789        }
790    }
791
792    /// Add sort operator
793    fn add_sort(
794        &self,
795        plan: PhysicalPlan,
796        order_by: Vec<(String, SortDirection)>,
797        _stats: &Option<TableStats>,
798    ) -> PhysicalPlan {
799        let estimated_rows = self.get_plan_rows(&plan);
800        let sort_cost = if estimated_rows > 0 {
801            estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
802        } else {
803            0.0
804        };
805
806        PhysicalPlan::Sort {
807            input: Box::new(plan),
808            order_by,
809            estimated_cost: sort_cost,
810        }
811    }
812
813    /// Get estimated rows from a plan
814    #[allow(clippy::only_used_in_recursion)]
815    fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
816        match plan {
817            PhysicalPlan::TableScan { estimated_rows, .. }
818            | PhysicalPlan::IndexSeek { estimated_rows, .. }
819            | PhysicalPlan::Filter { estimated_rows, .. }
820            | PhysicalPlan::Aggregate { estimated_rows, .. }
821            | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
822            | PhysicalPlan::HashJoin { estimated_rows, .. }
823            | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
824            PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
825                self.get_plan_rows(input)
826            }
827            PhysicalPlan::Limit { limit, .. } => *limit,
828        }
829    }
830
831    /// Get estimated cost from a plan
832    #[allow(clippy::only_used_in_recursion)]
833    pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
834        match plan {
835            PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
836            PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
837            PhysicalPlan::Filter {
838                estimated_cost,
839                input,
840                ..
841            } => *estimated_cost + self.get_plan_cost(input),
842            PhysicalPlan::Project {
843                estimated_cost,
844                input,
845                ..
846            } => *estimated_cost + self.get_plan_cost(input),
847            PhysicalPlan::Sort {
848                estimated_cost,
849                input,
850                ..
851            } => *estimated_cost + self.get_plan_cost(input),
852            PhysicalPlan::Limit {
853                estimated_cost,
854                input,
855                ..
856            } => *estimated_cost + self.get_plan_cost(input),
857            PhysicalPlan::NestedLoopJoin {
858                estimated_cost,
859                outer,
860                inner,
861                ..
862            } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
863            PhysicalPlan::HashJoin {
864                estimated_cost,
865                build,
866                probe,
867                ..
868            } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
869            PhysicalPlan::MergeJoin {
870                estimated_cost,
871                left,
872                right,
873                ..
874            } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
875            PhysicalPlan::Aggregate {
876                estimated_cost,
877                input,
878                ..
879            } => *estimated_cost + self.get_plan_cost(input),
880        }
881    }
882
883    /// Generate EXPLAIN output
884    pub fn explain(&self, plan: &PhysicalPlan) -> String {
885        self.explain_impl(plan, 0)
886    }
887
888    fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
889        let prefix = "  ".repeat(indent);
890        let cost = self.get_plan_cost(plan);
891
892        match plan {
893            PhysicalPlan::TableScan {
894                table,
895                columns,
896                estimated_rows,
897                ..
898            } => {
899                format!(
900                    "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
901                    prefix, table, columns, estimated_rows, cost
902                )
903            }
904            PhysicalPlan::IndexSeek {
905                table,
906                index,
907                columns,
908                estimated_rows,
909                ..
910            } => {
911                format!(
912                    "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
913                    prefix, table, index, columns, estimated_rows, cost
914                )
915            }
916            PhysicalPlan::Filter {
917                input,
918                estimated_rows,
919                ..
920            } => {
921                format!(
922                    "{}Filter [rows={}, cost={:.2}ms]\n{}",
923                    prefix,
924                    estimated_rows,
925                    cost,
926                    self.explain_impl(input, indent + 1)
927                )
928            }
929            PhysicalPlan::Project { input, columns, .. } => {
930                format!(
931                    "{}Project [columns={:?}, cost={:.2}ms]\n{}",
932                    prefix,
933                    columns,
934                    cost,
935                    self.explain_impl(input, indent + 1)
936                )
937            }
938            PhysicalPlan::Sort {
939                input, order_by, ..
940            } => {
941                let order: Vec<_> = order_by
942                    .iter()
943                    .map(|(c, d)| format!("{} {:?}", c, d))
944                    .collect();
945                format!(
946                    "{}Sort [order={:?}, cost={:.2}ms]\n{}",
947                    prefix,
948                    order,
949                    cost,
950                    self.explain_impl(input, indent + 1)
951                )
952            }
953            PhysicalPlan::Limit {
954                input,
955                limit,
956                offset,
957                ..
958            } => {
959                format!(
960                    "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
961                    prefix,
962                    limit,
963                    offset,
964                    cost,
965                    self.explain_impl(input, indent + 1)
966                )
967            }
968            PhysicalPlan::HashJoin {
969                build,
970                probe,
971                join_type,
972                estimated_rows,
973                ..
974            } => {
975                format!(
976                    "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
977                    prefix,
978                    join_type,
979                    estimated_rows,
980                    cost,
981                    self.explain_impl(build, indent + 1),
982                    self.explain_impl(probe, indent + 1)
983                )
984            }
985            PhysicalPlan::MergeJoin {
986                left,
987                right,
988                join_type,
989                estimated_rows,
990                ..
991            } => {
992                format!(
993                    "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
994                    prefix,
995                    join_type,
996                    estimated_rows,
997                    cost,
998                    self.explain_impl(left, indent + 1),
999                    self.explain_impl(right, indent + 1)
1000                )
1001            }
1002            PhysicalPlan::NestedLoopJoin {
1003                outer,
1004                inner,
1005                join_type,
1006                estimated_rows,
1007                ..
1008            } => {
1009                format!(
1010                    "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1011                    prefix,
1012                    join_type,
1013                    estimated_rows,
1014                    cost,
1015                    self.explain_impl(outer, indent + 1),
1016                    self.explain_impl(inner, indent + 1)
1017                )
1018            }
1019            PhysicalPlan::Aggregate {
1020                input,
1021                group_by,
1022                aggregates,
1023                estimated_rows,
1024                ..
1025            } => {
1026                let aggs: Vec<_> = aggregates
1027                    .iter()
1028                    .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
1029                    .collect();
1030                format!(
1031                    "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
1032                    prefix,
1033                    group_by,
1034                    aggs,
1035                    estimated_rows,
1036                    cost,
1037                    self.explain_impl(input, indent + 1)
1038                )
1039            }
1040        }
1041    }
1042}
1043
1044// ============================================================================
1045// Join Order Optimizer (Dynamic Programming)
1046// ============================================================================
1047
1048/// Join order optimizer using dynamic programming
1049pub struct JoinOrderOptimizer {
1050    /// Table statistics
1051    stats: HashMap<String, TableStats>,
1052    /// Cost model
1053    config: CostModelConfig,
1054}
1055
1056impl JoinOrderOptimizer {
1057    pub fn new(config: CostModelConfig) -> Self {
1058        Self {
1059            stats: HashMap::new(),
1060            config,
1061        }
1062    }
1063
1064    /// Add table statistics
1065    pub fn add_stats(&mut self, stats: TableStats) {
1066        self.stats.insert(stats.name.clone(), stats);
1067    }
1068
1069    /// Find optimal join order using dynamic programming
1070    ///
1071    /// Time: O(2^n × n^2) where n = number of tables
1072    /// Practical for n ≤ 10
1073    pub fn find_optimal_order(
1074        &self,
1075        tables: &[String],
1076        join_conditions: &[(String, String, String, String)], // (table1, col1, table2, col2)
1077    ) -> Vec<(String, String)> {
1078        let n = tables.len();
1079        if n <= 1 {
1080            return vec![];
1081        }
1082
1083        // dp[mask] = (cost, join_order)
1084        let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1085
1086        // Base case: single tables
1087        for (i, _table) in tables.iter().enumerate() {
1088            let mask = 1u32 << i;
1089            dp.insert(mask, (0.0, vec![]));
1090        }
1091
1092        // Build up larger subsets
1093        for size in 2..=n {
1094            for mask in 0..(1u32 << n) {
1095                if mask.count_ones() != size as u32 {
1096                    continue;
1097                }
1098
1099                let mut best_cost = f64::MAX;
1100                let mut best_order = vec![];
1101
1102                // Try all ways to split into two non-empty subsets
1103                for sub in 1..mask {
1104                    if sub & mask != sub || sub == 0 {
1105                        continue;
1106                    }
1107                    let other = mask ^ sub;
1108                    if other == 0 {
1109                        continue;
1110                    }
1111
1112                    // Check if there's a join between sub and other
1113                    if !self.has_join_condition(tables, sub, other, join_conditions) {
1114                        continue;
1115                    }
1116
1117                    if let (Some((cost1, order1)), Some((cost2, order2))) =
1118                        (dp.get(&sub), dp.get(&other))
1119                    {
1120                        let join_cost = self.estimate_join_cost(tables, sub, other);
1121                        let total_cost = cost1 + cost2 + join_cost;
1122
1123                        if total_cost < best_cost {
1124                            best_cost = total_cost;
1125                            best_order = order1.clone();
1126                            best_order.extend(order2.clone());
1127
1128                            // Add the join
1129                            let (t1, t2) =
1130                                self.get_join_tables(tables, sub, other, join_conditions);
1131                            if let Some((t1, t2)) = Some((t1, t2)) {
1132                                best_order.push((t1, t2));
1133                            }
1134                        }
1135                    }
1136                }
1137
1138                if best_cost < f64::MAX {
1139                    dp.insert(mask, (best_cost, best_order));
1140                }
1141            }
1142        }
1143
1144        let full_mask = (1u32 << n) - 1;
1145        dp.get(&full_mask)
1146            .map(|(_, order)| order.clone())
1147            .unwrap_or_default()
1148    }
1149
1150    fn has_join_condition(
1151        &self,
1152        tables: &[String],
1153        mask1: u32,
1154        mask2: u32,
1155        conditions: &[(String, String, String, String)],
1156    ) -> bool {
1157        for (t1, _, t2, _) in conditions {
1158            let idx1 = tables.iter().position(|t| t == t1);
1159            let idx2 = tables.iter().position(|t| t == t2);
1160
1161            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1162                let in_mask1 = (mask1 >> i1) & 1 == 1;
1163                let in_mask2 = (mask2 >> i2) & 1 == 1;
1164
1165                if in_mask1 && in_mask2 {
1166                    return true;
1167                }
1168            }
1169        }
1170        false
1171    }
1172
1173    fn get_join_tables(
1174        &self,
1175        tables: &[String],
1176        mask1: u32,
1177        mask2: u32,
1178        conditions: &[(String, String, String, String)],
1179    ) -> (String, String) {
1180        for (t1, _, t2, _) in conditions {
1181            let idx1 = tables.iter().position(|t| t == t1);
1182            let idx2 = tables.iter().position(|t| t == t2);
1183
1184            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1185                let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1186                let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1187
1188                if t1_in_mask1 && t2_in_mask2 {
1189                    return (t1.clone(), t2.clone());
1190                }
1191            }
1192        }
1193        (String::new(), String::new())
1194    }
1195
1196    fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1197        let rows1 = self.estimate_rows_for_mask(tables, mask1);
1198        let rows2 = self.estimate_rows_for_mask(tables, mask2);
1199
1200        // Hash join cost estimate
1201        // Build cost + probe cost
1202        let build_cost = rows1 as f64 * self.config.c_filter;
1203        let probe_cost = rows2 as f64 * self.config.c_filter;
1204
1205        build_cost + probe_cost
1206    }
1207
1208    fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1209        let mut total = 1u64;
1210
1211        for (i, table) in tables.iter().enumerate() {
1212            if (mask >> i) & 1 == 1 {
1213                let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1214                total = total.saturating_mul(rows);
1215            }
1216        }
1217
1218        // Apply default selectivity for joins
1219        let num_tables = mask.count_ones();
1220        if num_tables > 1 {
1221            total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1222        }
1223
1224        total.max(1)
1225    }
1226}
1227
1228// ============================================================================
1229// Tests
1230// ============================================================================
1231
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistics for a 100k-row `users` table: a unique `id` column backed by the
    /// primary B-tree `pk_users`, and a 100-distinct-value `score` column backed by
    /// the secondary B-tree `idx_score` with a uniform four-bucket histogram.
    fn create_test_stats() -> TableStats {
        let mut column_stats = HashMap::new();
        column_stats.insert(
            "id".to_string(),
            ColumnStats {
                name: "id".to_string(),
                distinct_count: 100000,
                null_count: 0,
                min_value: Some("1".to_string()),
                max_value: Some("100000".to_string()),
                avg_length: 8.0,
                mcv: vec![],
                histogram: None,
            },
        );
        column_stats.insert(
            "score".to_string(),
            ColumnStats {
                name: "score".to_string(),
                distinct_count: 100,
                null_count: 1000,
                min_value: Some("0".to_string()),
                max_value: Some("100".to_string()),
                avg_length: 8.0,
                mcv: vec![("50".to_string(), 0.05)],
                histogram: Some(Histogram {
                    boundaries: vec![25.0, 50.0, 75.0, 100.0],
                    counts: vec![25000, 25000, 25000, 25000],
                    total_rows: 100000,
                }),
            },
        );

        TableStats {
            name: "users".to_string(),
            row_count: 100000,
            size_bytes: 10_000_000, // 10 MB
            column_stats,
            indices: vec![
                IndexStats {
                    name: "pk_users".to_string(),
                    columns: vec!["id".to_string()],
                    is_primary: true,
                    is_unique: true,
                    index_type: IndexType::BTree,
                    leaf_pages: 1000,
                    height: 3,
                    avg_leaf_density: 100.0,
                },
                IndexStats {
                    name: "idx_score".to_string(),
                    columns: vec!["score".to_string()],
                    is_primary: false,
                    is_unique: false,
                    index_type: IndexType::BTree,
                    leaf_pages: 500,
                    height: 2,
                    avg_leaf_density: 200.0,
                },
            ],
            last_updated: 0,
        }
    }

    /// Minimal table statistics carrying only a name and a row count.
    fn bare_table_stats(name: &str, row_count: u64) -> TableStats {
        TableStats {
            name: name.to_string(),
            row_count,
            size_bytes: 0,
            column_stats: HashMap::new(),
            indices: vec![],
            last_updated: 0,
        }
    }

    /// A `(table1, "id", table2, "id")` join condition.
    fn id_join(t1: &str, t2: &str) -> (String, String, String, String) {
        (
            t1.to_string(),
            "id".to_string(),
            t2.to_string(),
            "id".to_string(),
        )
    }

    /// Default cost model with `c_filter` pinned to 1.0, so join costs equal
    /// estimated row counts regardless of the default constant.
    fn unit_filter_config() -> CostModelConfig {
        CostModelConfig {
            c_filter: 1.0,
            ..CostModelConfig::default()
        }
    }

    #[test]
    fn test_selectivity_estimation() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats.clone());

        // Equality predicate
        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        assert!(sel < 0.001); // Should be very selective

        // Range predicate with histogram
        // Note: For histogram boundaries [25, 50, 75, 100] with equal distribution,
        // Gt{75} includes buckets with bucket_max >= 75, which is buckets 2 and 3 (50%)
        let pred = Predicate::Gt {
            column: "score".to_string(),
            value: "75".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        assert!(sel > 0.4 && sel < 0.6); // ~50% from histogram (2 of 4 buckets)
    }

    #[test]
    fn test_access_path_selection() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats);

        // High selectivity predicate should use index
        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(pred),
            vec![],
            None,
        );

        match plan {
            PhysicalPlan::IndexSeek { index, .. } => {
                assert_eq!(index, "pk_users");
            }
            _ => panic!("Expected IndexSeek for equality on primary key"),
        }
    }

    #[test]
    fn test_token_budget_limit() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);

        // With 2048 token budget and 25 tokens/row:
        // max_rows = (2048 - 50) / 25 = ~80
        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);

        match plan {
            PhysicalPlan::Limit { limit, .. } => {
                assert!(limit <= 80);
            }
            _ => panic!("Expected Limit to be injected"),
        }
    }

    #[test]
    fn test_explain_output() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats);

        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(Predicate::Gt {
                column: "score".to_string(),
                value: "80".to_string(),
            }),
            vec![("score".to_string(), SortDirection::Descending)],
            Some(10),
        );

        let explain = optimizer.explain(&plan);
        assert!(explain.contains("Limit"));
        assert!(explain.contains("Sort"));
    }

    #[test]
    fn test_join_order_chain_covers_every_edge() {
        let optimizer = JoinOrderOptimizer::new(unit_filter_config());
        let tables = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let conditions = vec![id_join("a", "b"), id_join("b", "c")];

        // Without stats both connected splits of {a, b, c} cost the same; whichever
        // wins, a full plan must contain exactly the two chain joins.
        let order = optimizer.find_optimal_order(&tables, &conditions);
        assert_eq!(order.len(), 2);
        assert!(order.contains(&("a".to_string(), "b".to_string())));
        assert!(order.contains(&("b".to_string(), "c".to_string())));
    }

    #[test]
    fn test_join_order_prefers_small_intermediate_result() {
        let mut optimizer = JoinOrderOptimizer::new(unit_filter_config());
        optimizer.add_stats(bare_table_stats("a", 10));
        optimizer.add_stats(bare_table_stats("b", 1_000));
        optimizer.add_stats(bare_table_stats("c", 100_000));
        let tables = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let conditions = vec![id_join("a", "b"), id_join("b", "c")];

        // Joining a with b first keeps the intermediate result small, so a ⋈ b is
        // planned before the join with c.
        let order = optimizer.find_optimal_order(&tables, &conditions);
        assert_eq!(
            order,
            vec![
                ("a".to_string(), "b".to_string()),
                ("b".to_string(), "c".to_string()),
            ]
        );
    }

    #[test]
    fn test_join_order_disconnected_or_trivial() {
        let optimizer = JoinOrderOptimizer::new(CostModelConfig::default());
        let tables = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        // `c` is never joined, so no plan covers all three tables.
        assert!(optimizer
            .find_optimal_order(&tables, &[id_join("a", "b")])
            .is_empty());
        // A single table needs no joins.
        assert!(optimizer.find_optimal_order(&tables[..1], &[]).is_empty());
    }
}