Skip to main content

sochdb_query/
cost_optimizer.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Cost-Based Query Optimizer with Cardinality Estimation (Task 6)
19//!
20//! Provides cost-based query optimization for SOCH-QL with:
21//! - Cardinality estimation using sketches (HyperLogLog, CountMin)
22//! - Index selection: compare cost(table_scan) vs cost(index_seek)
23//! - Column projection pushdown to LSCS layer
24//! - Token-budget-aware planning
25//!
26//! ## Cost Model
27//!
28//! cost(plan) = I/O_cost + CPU_cost + memory_cost
29//!
30//! I/O_cost = blocks_read × C_seq + seeks × C_random
31//! Where:
32//!   C_seq = 0.1 ms/block (sequential read)
33//!   C_random = 5 ms/seek (random seek)
34//!
35//! CPU_cost = rows_processed × C_filter + sorts × N × log(N) × C_compare
36//!
37//! ## Selectivity Estimation
38//!
39//! Uses CountMinSketch for predicate selectivity and HyperLogLog for distinct counts.
40//!
41//! ## Token Budget Planning
42//!
43//! Given max_tokens, estimates result size and injects LIMIT clause:
44//!   max_rows = (max_tokens - header_tokens) / tokens_per_row
45
46use parking_lot::RwLock;
47use std::collections::{HashMap, HashSet};
48use std::sync::Arc;
49
50// ============================================================================
51// Cost Model Constants
52// ============================================================================
53
/// Tunable constants that drive the optimizer's cost formulas.
///
/// Time-valued fields are expressed in milliseconds so that the I/O and CPU
/// components of a plan's cost can be summed directly.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Cost of reading one block sequentially (ms)
    pub c_seq: f64,
    /// Cost of one random seek (ms)
    pub c_random: f64,
    /// CPU cost of evaluating a filter on one row (ms)
    pub c_filter: f64,
    /// CPU cost of a single comparison while sorting (ms)
    pub c_compare: f64,
    /// Size of a storage block in bytes
    pub block_size: usize,
    /// Number of entries per B-tree node, used for index cost estimation
    pub btree_fanout: usize,
    /// Memory bandwidth (bytes/ms)
    pub memory_bandwidth: f64,
}

impl Default for CostModelConfig {
    /// Builds the baseline configuration used when the caller supplies none.
    fn default() -> Self {
        // I/O side: a sequential block read is far cheaper than a seek.
        let (c_seq, c_random) = (0.1, 5.0);
        // CPU side: per-row filtering and per-comparison sort work.
        let (c_filter, c_compare) = (0.001, 0.0001);

        Self {
            c_seq,
            c_random,
            c_filter,
            c_compare,
            block_size: 4096,  // 4 KB blocks
            btree_fanout: 100, // 100 entries per B-tree node
            // NOTE(review): 10_000 bytes/ms is about 10 MB/s. An earlier note
            // equated this value with 10 GB/s, which would be 1e7 bytes/ms.
            // Confirm the intended unit before relying on this field.
            memory_bandwidth: 10_000.0,
        }
    }
}
86
87// ============================================================================
88// Statistics for Cardinality Estimation
89// ============================================================================
90
/// Per-table statistics consumed by the optimizer's cost estimation.
///
/// Instances are stored in the optimizer's statistics cache keyed by `name`.
#[derive(Debug, Clone)]
pub struct TableStats {
    /// Table name; also the key under which these stats are cached
    pub name: String,
    /// Total row count; the denominator for null-ratio selectivity estimates
    pub row_count: u64,
    /// Total size in bytes; converted to a block count for scan costing
    pub size_bytes: u64,
    /// Column statistics, keyed by column name
    pub column_stats: HashMap<String, ColumnStats>,
    /// Available indices considered as alternatives to a full table scan
    pub indices: Vec<IndexStats>,
    /// Last update timestamp (unit/epoch not established in this file — TODO confirm)
    pub last_updated: u64,
}
107
/// Per-column statistics used for predicate selectivity estimation.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name
    pub name: String,
    /// Distinct value count (from HyperLogLog); equality predicates that miss
    /// the MCV list are estimated as `1 / distinct_count`
    pub distinct_count: u64,
    /// Null count; divided by the table row count for IS NULL / IS NOT NULL
    pub null_count: u64,
    /// Minimum value (if orderable)
    pub min_value: Option<String>,
    /// Maximum value (if orderable)
    pub max_value: Option<String>,
    /// Average length in bytes (for variable-length types)
    pub avg_length: f64,
    /// Most common values paired with their frequency; the frequency is used
    /// directly as the selectivity of an equality match on that value
    pub mcv: Vec<(String, f64)>,
    /// Histogram buckets for range queries; `None` makes range predicates
    /// fall back to fixed default selectivities
    pub histogram: Option<Histogram>,
}
128
/// Histogram for range selectivity estimation.
///
/// Bucket `i` spans `(boundaries[i - 1], boundaries[i])`. The first bucket is
/// open towards negative infinity and the bucket after the last boundary is
/// open towards positive infinity, so a well-formed histogram has
/// `counts.len() == boundaries.len() + 1`.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Bucket boundaries
    pub boundaries: Vec<f64>,
    /// Row count per bucket
    pub counts: Vec<u64>,
    /// Total rows in histogram
    pub total_rows: u64,
}

impl Histogram {
    /// Estimate the fraction of rows whose value lies in `[min, max]`.
    ///
    /// A `None` bound leaves that side of the range open. Every bucket that
    /// overlaps the range contributes its full count, so the estimate is an
    /// upper bound at bucket granularity.
    ///
    /// Returns `0.5` when the histogram is empty (`total_rows == 0`); otherwise
    /// a value in `[0.0, 1.0]`.
    ///
    /// Malformed histograms never panic: counts beyond the
    /// `boundaries.len() + 1` buckets that have a defined range are ignored,
    /// and the result is capped at `1.0` even if the counts sum to more than
    /// `total_rows`.
    pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
        if self.total_rows == 0 {
            return 0.5; // No data: fall back to a neutral default
        }

        // Only this many buckets have a range defined by `boundaries`.
        let defined_buckets = self.boundaries.len() + 1;

        let selected_rows = self
            .counts
            .iter()
            .take(defined_buckets)
            .enumerate()
            .filter(|&(i, _)| {
                let bucket_min = i
                    .checked_sub(1)
                    .and_then(|prev| self.boundaries.get(prev))
                    .copied()
                    .unwrap_or(f64::NEG_INFINITY);
                let bucket_max = self.boundaries.get(i).copied().unwrap_or(f64::INFINITY);

                let above_min = min.map_or(true, |min_val| bucket_max >= min_val);
                let below_max = max.map_or(true, |max_val| bucket_min <= max_val);
                above_min && below_max
            })
            // Saturate rather than overflow on absurdly large counts.
            .fold(0u64, |acc, (_, &count)| acc.saturating_add(count));

        (selected_rows as f64 / self.total_rows as f64).min(1.0)
    }
}
176
/// Statistics describing one index, used to cost an index seek against a
/// full table scan.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Index name; recorded in the plan when this index is chosen
    pub name: String,
    /// Indexed columns, in key order; only the first one is consulted when
    /// deciding whether the index applies to a predicate
    pub columns: Vec<String>,
    /// Is primary key; the cost model treats a primary index as clustered and
    /// charges no extra row-fetch cost
    pub is_primary: bool,
    /// Is unique
    pub is_unique: bool,
    /// Index type
    pub index_type: IndexType,
    /// Number of leaf pages
    pub leaf_pages: u64,
    /// Tree height (for B-tree); each level is costed as one random I/O
    pub height: u32,
    /// Average entries per leaf page; divides the matching row count to
    /// estimate how many leaf pages a seek scans
    pub avg_leaf_density: f64,
}
197
/// Index types.
///
/// Identifies the kind of structure behind an index. The cost formulas visible
/// in this part of the file do not branch on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    BTree,
    Hash,
    LSM,
    Learned,
    Vector,
    Bloom,
}
208
209// ============================================================================
210// Query Predicates and Operations
211// ============================================================================
212
/// Filter expression tree used by the optimizer for cost estimation.
///
/// Leaf variants compare a single column against literal values carried as
/// strings; `And`, `Or` and `Not` combine sub-predicates.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// Equality: column = value
    Eq { column: String, value: String },
    /// Inequality: column != value
    Ne { column: String, value: String },
    /// Less than: column < value
    Lt { column: String, value: String },
    /// Less than or equal: column <= value
    Le { column: String, value: String },
    /// Greater than: column > value
    Gt { column: String, value: String },
    /// Greater than or equal: column >= value
    Ge { column: String, value: String },
    /// Between: column BETWEEN min AND max
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// In list: column IN (v1, v2, ...)
    In { column: String, values: Vec<String> },
    /// Like: column LIKE pattern
    Like { column: String, pattern: String },
    /// Is null: column IS NULL
    IsNull { column: String },
    /// Is not null: column IS NOT NULL
    IsNotNull { column: String },
    /// And: pred1 AND pred2
    And(Box<Predicate>, Box<Predicate>),
    /// Or: pred1 OR pred2
    Or(Box<Predicate>, Box<Predicate>),
    /// Not: NOT pred
    Not(Box<Predicate>),
}

impl Predicate {
    /// Returns the set of column names mentioned anywhere in this predicate.
    pub fn referenced_columns(&self) -> HashSet<String> {
        let mut found = HashSet::new();
        self.collect_columns(&mut found);
        found
    }

    /// Adds every column named in this predicate tree to `cols`.
    ///
    /// Walks the tree with an explicit work list instead of recursion; the
    /// visiting order is irrelevant because the result is a set.
    fn collect_columns(&self, cols: &mut HashSet<String>) {
        let mut pending: Vec<&Predicate> = vec![self];

        while let Some(node) = pending.pop() {
            match node {
                Self::And(lhs, rhs) | Self::Or(lhs, rhs) => {
                    pending.push(&**rhs);
                    pending.push(&**lhs);
                }
                Self::Not(inner) => pending.push(&**inner),
                Self::Eq { column, .. }
                | Self::Ne { column, .. }
                | Self::Lt { column, .. }
                | Self::Le { column, .. }
                | Self::Gt { column, .. }
                | Self::Ge { column, .. }
                | Self::Between { column, .. }
                | Self::In { column, .. }
                | Self::Like { column, .. }
                | Self::IsNull { column }
                | Self::IsNotNull { column } => {
                    cols.insert(column.clone());
                }
            }
        }
    }
}
281
282// ============================================================================
283// Physical Plan Operators
284// ============================================================================
285
/// Physical query plan node.
///
/// Plans form a tree: leaf nodes (`TableScan`, `IndexSeek`) read data and the
/// other variants wrap one or two child plans. Each node's `estimated_cost`
/// is that node's own contribution; `get_plan_cost` sums it with the cost of
/// the node's children. EXPLAIN output labels costs as milliseconds.
#[derive(Debug, Clone)]
pub enum PhysicalPlan {
    /// Table scan (full or partial)
    TableScan {
        /// Table being scanned
        table: String,
        /// Columns to read (the projection pushed down to the scan)
        columns: Vec<String>,
        /// Optional filter evaluated during the scan
        predicate: Option<Box<Predicate>>,
        /// Estimated number of rows for this node
        estimated_rows: u64,
        /// Estimated cost of the scan
        estimated_cost: f64,
    },
    /// Index seek
    IndexSeek {
        /// Table the index belongs to
        table: String,
        /// Name of the index used
        index: String,
        /// Columns to read (the projection pushed down to the seek)
        columns: Vec<String>,
        /// Key range to traverse within the index
        key_range: KeyRange,
        /// Optional filter associated with the seek
        predicate: Option<Box<Predicate>>,
        /// Estimated number of matching rows
        estimated_rows: u64,
        /// Estimated cost of the seek
        estimated_cost: f64,
    },
    /// Filter operator
    Filter {
        /// Child plan producing the rows to filter
        input: Box<PhysicalPlan>,
        /// Condition rows must satisfy
        predicate: Predicate,
        /// Estimated number of rows that pass the filter
        estimated_rows: u64,
        /// Estimated cost of this operator, excluding its input
        estimated_cost: f64,
    },
    /// Project operator (column subset)
    Project {
        /// Child plan producing the rows to project
        input: Box<PhysicalPlan>,
        /// Columns kept in the output
        columns: Vec<String>,
        /// Estimated cost of this operator, excluding its input
        estimated_cost: f64,
    },
    /// Sort operator
    Sort {
        /// Child plan producing the rows to sort
        input: Box<PhysicalPlan>,
        /// Sort keys with their directions, in priority order
        order_by: Vec<(String, SortDirection)>,
        /// Estimated cost of this operator, excluding its input
        estimated_cost: f64,
    },
    /// Limit operator
    Limit {
        /// Child plan producing the rows to truncate
        input: Box<PhysicalPlan>,
        /// Maximum number of rows to emit
        limit: u64,
        /// Number of leading rows to skip
        offset: u64,
        /// Estimated cost of this operator, excluding its input
        estimated_cost: f64,
    },
    /// Nested loop join
    NestedLoopJoin {
        /// Outer (driving) side of the join
        outer: Box<PhysicalPlan>,
        /// Inner side of the join
        inner: Box<PhysicalPlan>,
        /// Join condition
        condition: Predicate,
        /// Kind of join performed
        join_type: JoinType,
        /// Estimated number of output rows
        estimated_rows: u64,
        /// Estimated cost of this operator, excluding its inputs
        estimated_cost: f64,
    },
    /// Hash join
    HashJoin {
        /// Side used to build the hash table
        build: Box<PhysicalPlan>,
        /// Side probed against the hash table
        probe: Box<PhysicalPlan>,
        /// Join key columns on the build side
        build_keys: Vec<String>,
        /// Join key columns on the probe side
        probe_keys: Vec<String>,
        /// Kind of join performed
        join_type: JoinType,
        /// Estimated number of output rows
        estimated_rows: u64,
        /// Estimated cost of this operator, excluding its inputs
        estimated_cost: f64,
    },
    /// Merge join
    MergeJoin {
        /// Left input
        left: Box<PhysicalPlan>,
        /// Right input
        right: Box<PhysicalPlan>,
        /// Join key columns on the left side
        left_keys: Vec<String>,
        /// Join key columns on the right side
        right_keys: Vec<String>,
        /// Kind of join performed
        join_type: JoinType,
        /// Estimated number of output rows
        estimated_rows: u64,
        /// Estimated cost of this operator, excluding its inputs
        estimated_cost: f64,
    },
    /// Aggregate operator
    Aggregate {
        /// Child plan producing the rows to aggregate
        input: Box<PhysicalPlan>,
        /// Grouping columns
        group_by: Vec<String>,
        /// Aggregate expressions computed per group
        aggregates: Vec<AggregateExpr>,
        /// Estimated number of output rows
        estimated_rows: u64,
        /// Estimated cost of this operator, excluding its input
        estimated_cost: f64,
    },
}
371
/// Byte-encoded key interval used by index seeks.
///
/// A `None` endpoint leaves that side of the interval unbounded; the two
/// `*_inclusive` flags say whether each endpoint itself belongs to the range.
#[derive(Debug, Clone)]
pub struct KeyRange {
    pub start: Option<Vec<u8>>,
    pub end: Option<Vec<u8>>,
    pub start_inclusive: bool,
    pub end_inclusive: bool,
}

impl KeyRange {
    /// Unbounded range covering every key (both flags set to inclusive).
    pub fn all() -> Self {
        Self::range(None, None, true)
    }

    /// Range matching exactly one key: both endpoints are `key`, inclusive.
    pub fn point(key: Vec<u8>) -> Self {
        let lower = key.clone();
        Self::range(Some(lower), Some(key), true)
    }

    /// Range between optional endpoints, with `inclusive` applied to both ends.
    pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
        let (start_inclusive, end_inclusive) = (inclusive, inclusive);
        Self {
            start,
            end,
            start_inclusive,
            end_inclusive,
        }
    }
}
409
/// Sort direction for one ORDER BY key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Ascending,
    Descending,
}
416
/// Join type carried by the join plan nodes.
///
/// The cost code visible in this part of the file does not branch on the
/// variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    Cross,
}
426
/// Aggregate expression: one aggregate function applied to an optional column
/// and exposed under an output alias.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// Aggregate function to apply
    pub function: AggregateFunction,
    /// Input column; `None` presumably means no column argument
    /// (e.g. a `COUNT(*)`-style call) — TODO confirm against the planner
    pub column: Option<String>,
    /// Name of the result column
    pub alias: String,
}
434
/// Aggregate functions that an `AggregateExpr` can apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    CountDistinct,
}
445
446// ============================================================================
447// Cost-Based Query Optimizer
448// ============================================================================
449
/// Cost-based query optimizer.
///
/// Combines a cost model, a cache of per-table statistics and an optional
/// token budget to turn a logical SELECT into a `PhysicalPlan`.
pub struct CostBasedOptimizer {
    /// Cost model configuration
    config: CostModelConfig,
    /// Table statistics cache, keyed by table name and guarded by a
    /// read-write lock behind an `Arc`
    stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
    /// Token budget for result limiting; `None` disables token-based limits
    token_budget: Option<u64>,
    /// Estimated tokens per row, used to convert the budget into a row limit
    tokens_per_row: f64,
}
461
462impl CostBasedOptimizer {
463    pub fn new(config: CostModelConfig) -> Self {
464        Self {
465            config,
466            stats_cache: Arc::new(RwLock::new(HashMap::new())),
467            token_budget: None,
468            tokens_per_row: 25.0, // Default estimate
469        }
470    }
471
472    /// Set token budget for result limiting
473    pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
474        self.token_budget = Some(budget);
475        self.tokens_per_row = tokens_per_row;
476        self
477    }
478
479    /// Update table statistics
480    pub fn update_stats(&self, stats: TableStats) {
481        self.stats_cache.write().insert(stats.name.clone(), stats);
482    }
483
484    /// Get table statistics
485    pub fn get_stats(&self, table: &str) -> Option<TableStats> {
486        self.stats_cache.read().get(table).cloned()
487    }
488
489    /// Optimize a SELECT query
490    pub fn optimize(
491        &self,
492        table: &str,
493        columns: Vec<String>,
494        predicate: Option<Predicate>,
495        order_by: Vec<(String, SortDirection)>,
496        limit: Option<u64>,
497    ) -> PhysicalPlan {
498        let stats = self.get_stats(table);
499
500        // Calculate token-aware limit
501        let effective_limit = self.calculate_token_limit(limit);
502
503        // Get best access path (scan vs index)
504        let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
505
506        // Apply column projection pushdown
507        plan = self.apply_projection_pushdown(plan, columns.clone());
508
509        // Apply sorting if needed
510        if !order_by.is_empty() {
511            plan = self.add_sort(plan, order_by, &stats);
512        }
513
514        // Apply limit
515        if let Some(lim) = effective_limit {
516            plan = PhysicalPlan::Limit {
517                estimated_cost: 0.0,
518                input: Box::new(plan),
519                limit: lim,
520                offset: 0,
521            };
522        }
523
524        plan
525    }
526
527    /// Calculate token-aware limit
528    fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
529        match (self.token_budget, user_limit) {
530            (Some(budget), Some(limit)) => {
531                let header_tokens = 50u64;
532                let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
533                Some(limit.min(max_rows))
534            }
535            (Some(budget), None) => {
536                let header_tokens = 50u64;
537                let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
538                Some(max_rows)
539            }
540            (None, limit) => limit,
541        }
542    }
543
544    /// Choose best access path (table scan vs index seek)
545    fn choose_access_path(
546        &self,
547        table: &str,
548        columns: &[String],
549        predicate: Option<&Predicate>,
550        stats: &Option<TableStats>,
551    ) -> PhysicalPlan {
552        let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
553        let size_bytes = stats
554            .as_ref()
555            .map(|s| s.size_bytes)
556            .unwrap_or(row_count * 100);
557
558        // Calculate table scan cost
559        let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
560
561        // Try to find a suitable index
562        let mut best_index_cost = f64::MAX;
563        let mut best_index: Option<&IndexStats> = None;
564
565        if let Some(table_stats) = stats.as_ref()
566            && let Some(pred) = predicate
567        {
568            let pred_columns = pred.referenced_columns();
569
570            for index in &table_stats.indices {
571                if self.index_covers_predicate(index, &pred_columns) {
572                    let selectivity = self.estimate_selectivity(pred, table_stats);
573                    let index_cost = self.estimate_index_cost(index, row_count, selectivity);
574
575                    if index_cost < best_index_cost {
576                        best_index_cost = index_cost;
577                        best_index = Some(index);
578                    }
579                }
580            }
581        }
582
583        // Choose cheaper option
584        if best_index_cost < scan_cost {
585            let index = best_index.unwrap();
586            let selectivity = predicate
587                .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
588                .unwrap_or(1.0);
589
590            PhysicalPlan::IndexSeek {
591                table: table.to_string(),
592                index: index.name.clone(),
593                columns: columns.to_vec(),
594                key_range: KeyRange::all(), // Simplified
595                predicate: predicate.map(|p| Box::new(p.clone())),
596                estimated_rows: (row_count as f64 * selectivity) as u64,
597                estimated_cost: best_index_cost,
598            }
599        } else {
600            PhysicalPlan::TableScan {
601                table: table.to_string(),
602                columns: columns.to_vec(),
603                predicate: predicate.map(|p| Box::new(p.clone())),
604                estimated_rows: row_count,
605                estimated_cost: scan_cost,
606            }
607        }
608    }
609
610    /// Check if index covers predicate columns
611    fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
612        // Index is useful if it covers at least the first column of the predicate
613        if let Some(first_col) = index.columns.first() {
614            pred_columns.contains(first_col)
615        } else {
616            false
617        }
618    }
619
620    /// Estimate table scan cost
621    fn estimate_scan_cost(
622        &self,
623        row_count: u64,
624        size_bytes: u64,
625        predicate: Option<&Predicate>,
626    ) -> f64 {
627        let blocks = (size_bytes as f64 / self.config.block_size as f64).ceil() as u64;
628
629        // I/O cost: sequential read
630        let io_cost = blocks as f64 * self.config.c_seq;
631
632        // CPU cost: filter all rows
633        let selectivity = predicate.map(|_| 0.1).unwrap_or(1.0);
634        let cpu_cost = row_count as f64 * self.config.c_filter * selectivity;
635
636        io_cost + cpu_cost
637    }
638
639    /// Estimate index seek cost
640    ///
641    /// Index cost = tree_traversal + leaf_scan + row_fetch
642    fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
643        // Tree traversal cost (random I/O for each level)
644        let tree_cost = index.height as f64 * self.config.c_random;
645
646        // Leaf scan cost (sequential for matching range)
647        let matching_rows = (total_rows as f64 * selectivity) as u64;
648        let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
649        let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
650
651        // Row fetch cost (random if not clustered)
652        let fetch_cost = if index.is_primary {
653            0.0 // Clustered index, no extra fetch
654        } else {
655            matching_rows.min(1000) as f64 * self.config.c_random * 0.1 // Batch optimization
656        };
657
658        tree_cost + leaf_cost + fetch_cost
659    }
660
661    /// Estimate predicate selectivity
662    #[allow(clippy::only_used_in_recursion)]
663    fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
664        match predicate {
665            Predicate::Eq { column, value } => {
666                if let Some(col_stats) = stats.column_stats.get(column) {
667                    // Check MCV first
668                    for (mcv_val, freq) in &col_stats.mcv {
669                        if mcv_val == value {
670                            return *freq;
671                        }
672                    }
673                    // Otherwise use uniform distribution
674                    1.0 / col_stats.distinct_count.max(1) as f64
675                } else {
676                    0.1 // Default 10%
677                }
678            }
679            Predicate::Ne { .. } => 0.9, // 90% pass
680            Predicate::Lt { column, value }
681            | Predicate::Le { column, value }
682            | Predicate::Gt { column, value }
683            | Predicate::Ge { column, value } => {
684                if let Some(col_stats) = stats.column_stats.get(column) {
685                    if let Some(ref hist) = col_stats.histogram {
686                        let val: f64 = value.parse().unwrap_or(0.0);
687                        match predicate {
688                            Predicate::Lt { .. } | Predicate::Le { .. } => {
689                                hist.estimate_range_selectivity(None, Some(val))
690                            }
691                            _ => hist.estimate_range_selectivity(Some(val), None),
692                        }
693                    } else {
694                        0.25 // Default 25%
695                    }
696                } else {
697                    0.25
698                }
699            }
700            Predicate::Between { column, min, max } => {
701                if let Some(col_stats) = stats.column_stats.get(column) {
702                    if let Some(ref hist) = col_stats.histogram {
703                        let min_val: f64 = min.parse().unwrap_or(0.0);
704                        let max_val: f64 = max.parse().unwrap_or(f64::MAX);
705                        hist.estimate_range_selectivity(Some(min_val), Some(max_val))
706                    } else {
707                        0.2
708                    }
709                } else {
710                    0.2
711                }
712            }
713            Predicate::In { column, values } => {
714                if let Some(col_stats) = stats.column_stats.get(column) {
715                    (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
716                } else {
717                    (values.len() as f64 * 0.1).min(0.5)
718                }
719            }
720            Predicate::Like { .. } => 0.15, // Default 15%
721            Predicate::IsNull { column } => {
722                if let Some(col_stats) = stats.column_stats.get(column) {
723                    col_stats.null_count as f64 / stats.row_count.max(1) as f64
724                } else {
725                    0.01
726                }
727            }
728            Predicate::IsNotNull { column } => {
729                if let Some(col_stats) = stats.column_stats.get(column) {
730                    1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
731                } else {
732                    0.99
733                }
734            }
735            Predicate::And(left, right) => {
736                // Assume independence
737                self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
738            }
739            Predicate::Or(left, right) => {
740                let s1 = self.estimate_selectivity(left, stats);
741                let s2 = self.estimate_selectivity(right, stats);
742                // P(A or B) = P(A) + P(B) - P(A and B)
743                (s1 + s2 - s1 * s2).min(1.0)
744            }
745            Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
746        }
747    }
748
749    /// Apply column projection pushdown
750    fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
751        // If plan already has projection, merge; otherwise add
752        match plan {
753            PhysicalPlan::TableScan {
754                table,
755                predicate,
756                estimated_rows,
757                estimated_cost,
758                ..
759            } => {
760                PhysicalPlan::TableScan {
761                    table,
762                    columns, // Pushed down columns
763                    predicate,
764                    estimated_rows,
765                    estimated_cost: estimated_cost * 0.2, // Reduce cost estimate
766                }
767            }
768            PhysicalPlan::IndexSeek {
769                table,
770                index,
771                key_range,
772                predicate,
773                estimated_rows,
774                estimated_cost,
775                ..
776            } => {
777                PhysicalPlan::IndexSeek {
778                    table,
779                    index,
780                    columns, // Pushed down columns
781                    key_range,
782                    predicate,
783                    estimated_rows,
784                    estimated_cost,
785                }
786            }
787            other => PhysicalPlan::Project {
788                input: Box::new(other),
789                columns,
790                estimated_cost: 0.0,
791            },
792        }
793    }
794
795    /// Add sort operator
796    fn add_sort(
797        &self,
798        plan: PhysicalPlan,
799        order_by: Vec<(String, SortDirection)>,
800        _stats: &Option<TableStats>,
801    ) -> PhysicalPlan {
802        let estimated_rows = self.get_plan_rows(&plan);
803        let sort_cost = if estimated_rows > 0 {
804            estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
805        } else {
806            0.0
807        };
808
809        PhysicalPlan::Sort {
810            input: Box::new(plan),
811            order_by,
812            estimated_cost: sort_cost,
813        }
814    }
815
816    /// Get estimated rows from a plan
817    #[allow(clippy::only_used_in_recursion)]
818    fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
819        match plan {
820            PhysicalPlan::TableScan { estimated_rows, .. }
821            | PhysicalPlan::IndexSeek { estimated_rows, .. }
822            | PhysicalPlan::Filter { estimated_rows, .. }
823            | PhysicalPlan::Aggregate { estimated_rows, .. }
824            | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
825            | PhysicalPlan::HashJoin { estimated_rows, .. }
826            | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
827            PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
828                self.get_plan_rows(input)
829            }
830            PhysicalPlan::Limit { limit, .. } => *limit,
831        }
832    }
833
834    /// Get estimated cost from a plan
835    #[allow(clippy::only_used_in_recursion)]
836    pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
837        match plan {
838            PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
839            PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
840            PhysicalPlan::Filter {
841                estimated_cost,
842                input,
843                ..
844            } => *estimated_cost + self.get_plan_cost(input),
845            PhysicalPlan::Project {
846                estimated_cost,
847                input,
848                ..
849            } => *estimated_cost + self.get_plan_cost(input),
850            PhysicalPlan::Sort {
851                estimated_cost,
852                input,
853                ..
854            } => *estimated_cost + self.get_plan_cost(input),
855            PhysicalPlan::Limit {
856                estimated_cost,
857                input,
858                ..
859            } => *estimated_cost + self.get_plan_cost(input),
860            PhysicalPlan::NestedLoopJoin {
861                estimated_cost,
862                outer,
863                inner,
864                ..
865            } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
866            PhysicalPlan::HashJoin {
867                estimated_cost,
868                build,
869                probe,
870                ..
871            } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
872            PhysicalPlan::MergeJoin {
873                estimated_cost,
874                left,
875                right,
876                ..
877            } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
878            PhysicalPlan::Aggregate {
879                estimated_cost,
880                input,
881                ..
882            } => *estimated_cost + self.get_plan_cost(input),
883        }
884    }
885
886    /// Generate EXPLAIN output
887    pub fn explain(&self, plan: &PhysicalPlan) -> String {
888        self.explain_impl(plan, 0)
889    }
890
891    fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
892        let prefix = "  ".repeat(indent);
893        let cost = self.get_plan_cost(plan);
894
895        match plan {
896            PhysicalPlan::TableScan {
897                table,
898                columns,
899                estimated_rows,
900                ..
901            } => {
902                format!(
903                    "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
904                    prefix, table, columns, estimated_rows, cost
905                )
906            }
907            PhysicalPlan::IndexSeek {
908                table,
909                index,
910                columns,
911                estimated_rows,
912                ..
913            } => {
914                format!(
915                    "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
916                    prefix, table, index, columns, estimated_rows, cost
917                )
918            }
919            PhysicalPlan::Filter {
920                input,
921                estimated_rows,
922                ..
923            } => {
924                format!(
925                    "{}Filter [rows={}, cost={:.2}ms]\n{}",
926                    prefix,
927                    estimated_rows,
928                    cost,
929                    self.explain_impl(input, indent + 1)
930                )
931            }
932            PhysicalPlan::Project { input, columns, .. } => {
933                format!(
934                    "{}Project [columns={:?}, cost={:.2}ms]\n{}",
935                    prefix,
936                    columns,
937                    cost,
938                    self.explain_impl(input, indent + 1)
939                )
940            }
941            PhysicalPlan::Sort {
942                input, order_by, ..
943            } => {
944                let order: Vec<_> = order_by
945                    .iter()
946                    .map(|(c, d)| format!("{} {:?}", c, d))
947                    .collect();
948                format!(
949                    "{}Sort [order={:?}, cost={:.2}ms]\n{}",
950                    prefix,
951                    order,
952                    cost,
953                    self.explain_impl(input, indent + 1)
954                )
955            }
956            PhysicalPlan::Limit {
957                input,
958                limit,
959                offset,
960                ..
961            } => {
962                format!(
963                    "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
964                    prefix,
965                    limit,
966                    offset,
967                    cost,
968                    self.explain_impl(input, indent + 1)
969                )
970            }
971            PhysicalPlan::HashJoin {
972                build,
973                probe,
974                join_type,
975                estimated_rows,
976                ..
977            } => {
978                format!(
979                    "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
980                    prefix,
981                    join_type,
982                    estimated_rows,
983                    cost,
984                    self.explain_impl(build, indent + 1),
985                    self.explain_impl(probe, indent + 1)
986                )
987            }
988            PhysicalPlan::MergeJoin {
989                left,
990                right,
991                join_type,
992                estimated_rows,
993                ..
994            } => {
995                format!(
996                    "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
997                    prefix,
998                    join_type,
999                    estimated_rows,
1000                    cost,
1001                    self.explain_impl(left, indent + 1),
1002                    self.explain_impl(right, indent + 1)
1003                )
1004            }
1005            PhysicalPlan::NestedLoopJoin {
1006                outer,
1007                inner,
1008                join_type,
1009                estimated_rows,
1010                ..
1011            } => {
1012                format!(
1013                    "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1014                    prefix,
1015                    join_type,
1016                    estimated_rows,
1017                    cost,
1018                    self.explain_impl(outer, indent + 1),
1019                    self.explain_impl(inner, indent + 1)
1020                )
1021            }
1022            PhysicalPlan::Aggregate {
1023                input,
1024                group_by,
1025                aggregates,
1026                estimated_rows,
1027                ..
1028            } => {
1029                let aggs: Vec<_> = aggregates
1030                    .iter()
1031                    .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
1032                    .collect();
1033                format!(
1034                    "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
1035                    prefix,
1036                    group_by,
1037                    aggs,
1038                    estimated_rows,
1039                    cost,
1040                    self.explain_impl(input, indent + 1)
1041                )
1042            }
1043        }
1044    }
1045}
1046
1047// ============================================================================
1048// Join Order Optimizer (Dynamic Programming)
1049// ============================================================================
1050
1051/// Join order optimizer using dynamic programming
1052pub struct JoinOrderOptimizer {
1053    /// Table statistics
1054    stats: HashMap<String, TableStats>,
1055    /// Cost model
1056    config: CostModelConfig,
1057}
1058
1059impl JoinOrderOptimizer {
1060    pub fn new(config: CostModelConfig) -> Self {
1061        Self {
1062            stats: HashMap::new(),
1063            config,
1064        }
1065    }
1066
1067    /// Add table statistics
1068    pub fn add_stats(&mut self, stats: TableStats) {
1069        self.stats.insert(stats.name.clone(), stats);
1070    }
1071
1072    /// Find optimal join order using dynamic programming
1073    ///
1074    /// Time: O(2^n × n^2) where n = number of tables
1075    /// Practical for n ≤ 10
1076    pub fn find_optimal_order(
1077        &self,
1078        tables: &[String],
1079        join_conditions: &[(String, String, String, String)], // (table1, col1, table2, col2)
1080    ) -> Vec<(String, String)> {
1081        let n = tables.len();
1082        if n <= 1 {
1083            return vec![];
1084        }
1085
1086        // dp[mask] = (cost, join_order)
1087        let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1088
1089        // Base case: single tables
1090        for (i, _table) in tables.iter().enumerate() {
1091            let mask = 1u32 << i;
1092            dp.insert(mask, (0.0, vec![]));
1093        }
1094
1095        // Build up larger subsets
1096        for size in 2..=n {
1097            for mask in 0..(1u32 << n) {
1098                if mask.count_ones() != size as u32 {
1099                    continue;
1100                }
1101
1102                let mut best_cost = f64::MAX;
1103                let mut best_order = vec![];
1104
1105                // Try all ways to split into two non-empty subsets
1106                for sub in 1..mask {
1107                    if sub & mask != sub || sub == 0 {
1108                        continue;
1109                    }
1110                    let other = mask ^ sub;
1111                    if other == 0 {
1112                        continue;
1113                    }
1114
1115                    // Check if there's a join between sub and other
1116                    if !self.has_join_condition(tables, sub, other, join_conditions) {
1117                        continue;
1118                    }
1119
1120                    if let (Some((cost1, order1)), Some((cost2, order2))) =
1121                        (dp.get(&sub), dp.get(&other))
1122                    {
1123                        let join_cost = self.estimate_join_cost(tables, sub, other);
1124                        let total_cost = cost1 + cost2 + join_cost;
1125
1126                        if total_cost < best_cost {
1127                            best_cost = total_cost;
1128                            best_order = order1.clone();
1129                            best_order.extend(order2.clone());
1130
1131                            // Add the join
1132                            let (t1, t2) =
1133                                self.get_join_tables(tables, sub, other, join_conditions);
1134                            if let Some((t1, t2)) = Some((t1, t2)) {
1135                                best_order.push((t1, t2));
1136                            }
1137                        }
1138                    }
1139                }
1140
1141                if best_cost < f64::MAX {
1142                    dp.insert(mask, (best_cost, best_order));
1143                }
1144            }
1145        }
1146
1147        let full_mask = (1u32 << n) - 1;
1148        dp.get(&full_mask)
1149            .map(|(_, order)| order.clone())
1150            .unwrap_or_default()
1151    }
1152
1153    fn has_join_condition(
1154        &self,
1155        tables: &[String],
1156        mask1: u32,
1157        mask2: u32,
1158        conditions: &[(String, String, String, String)],
1159    ) -> bool {
1160        for (t1, _, t2, _) in conditions {
1161            let idx1 = tables.iter().position(|t| t == t1);
1162            let idx2 = tables.iter().position(|t| t == t2);
1163
1164            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1165                let in_mask1 = (mask1 >> i1) & 1 == 1;
1166                let in_mask2 = (mask2 >> i2) & 1 == 1;
1167
1168                if in_mask1 && in_mask2 {
1169                    return true;
1170                }
1171            }
1172        }
1173        false
1174    }
1175
1176    fn get_join_tables(
1177        &self,
1178        tables: &[String],
1179        mask1: u32,
1180        mask2: u32,
1181        conditions: &[(String, String, String, String)],
1182    ) -> (String, String) {
1183        for (t1, _, t2, _) in conditions {
1184            let idx1 = tables.iter().position(|t| t == t1);
1185            let idx2 = tables.iter().position(|t| t == t2);
1186
1187            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1188                let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1189                let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1190
1191                if t1_in_mask1 && t2_in_mask2 {
1192                    return (t1.clone(), t2.clone());
1193                }
1194            }
1195        }
1196        (String::new(), String::new())
1197    }
1198
1199    fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1200        let rows1 = self.estimate_rows_for_mask(tables, mask1);
1201        let rows2 = self.estimate_rows_for_mask(tables, mask2);
1202
1203        // Hash join cost estimate
1204        // Build cost + probe cost
1205        let build_cost = rows1 as f64 * self.config.c_filter;
1206        let probe_cost = rows2 as f64 * self.config.c_filter;
1207
1208        build_cost + probe_cost
1209    }
1210
1211    fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1212        let mut total = 1u64;
1213
1214        for (i, table) in tables.iter().enumerate() {
1215            if (mask >> i) & 1 == 1 {
1216                let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1217                total = total.saturating_mul(rows);
1218            }
1219        }
1220
1221        // Apply default selectivity for joins
1222        let num_tables = mask.count_ones();
1223        if num_tables > 1 {
1224            total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1225        }
1226
1227        total.max(1)
1228    }
1229}
1230
1231// ============================================================================
1232// Tests
1233// ============================================================================
1234
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistics fixture for a 100k-row `users` table with two columns
    /// (`id`, `score`) and one B-tree index on each.
    fn create_test_stats() -> TableStats {
        // Unique key column: one distinct value per row, no histogram.
        let id_column = ColumnStats {
            name: String::from("id"),
            distinct_count: 100000,
            null_count: 0,
            min_value: Some(String::from("1")),
            max_value: Some(String::from("100000")),
            avg_length: 8.0,
            mcv: vec![],
            histogram: None,
        };

        // Low-cardinality column with four equally filled histogram buckets.
        let score_histogram = Histogram {
            boundaries: vec![25.0, 50.0, 75.0, 100.0],
            counts: vec![25000, 25000, 25000, 25000],
            total_rows: 100000,
        };
        let score_column = ColumnStats {
            name: String::from("score"),
            distinct_count: 100,
            null_count: 1000,
            min_value: Some(String::from("0")),
            max_value: Some(String::from("100")),
            avg_length: 8.0,
            mcv: vec![(String::from("50"), 0.05)],
            histogram: Some(score_histogram),
        };

        let mut column_stats = HashMap::new();
        for column in vec![id_column, score_column] {
            column_stats.insert(column.name.clone(), column);
        }

        let primary_key = IndexStats {
            name: String::from("pk_users"),
            columns: vec![String::from("id")],
            is_primary: true,
            is_unique: true,
            index_type: IndexType::BTree,
            leaf_pages: 1000,
            height: 3,
            avg_leaf_density: 100.0,
        };
        let score_index = IndexStats {
            name: String::from("idx_score"),
            columns: vec![String::from("score")],
            is_primary: false,
            is_unique: false,
            index_type: IndexType::BTree,
            leaf_pages: 500,
            height: 2,
            avg_leaf_density: 200.0,
        };

        TableStats {
            name: String::from("users"),
            row_count: 100000,
            size_bytes: 10_000_000, // 10 MB
            column_stats,
            indices: vec![primary_key, score_index],
            last_updated: 0,
        }
    }

    #[test]
    fn test_selectivity_estimation() {
        let optimizer = CostBasedOptimizer::new(CostModelConfig::default());

        let stats = create_test_stats();
        optimizer.update_stats(stats.clone());

        // Equality on the unique `id` column should be very selective.
        let by_id = Predicate::Eq {
            column: String::from("id"),
            value: String::from("12345"),
        };
        let eq_selectivity = optimizer.estimate_selectivity(&by_id, &stats);
        assert!(eq_selectivity < 0.001);

        // Range predicate answered from the histogram.
        // Note: For histogram boundaries [25, 50, 75, 100] with equal distribution,
        // Gt{75} includes buckets with bucket_max >= 75, which is buckets 2 and 3 (50%)
        let high_scores = Predicate::Gt {
            column: String::from("score"),
            value: String::from("75"),
        };
        let range_selectivity = optimizer.estimate_selectivity(&high_scores, &stats);
        // ~50% from histogram (2 of 4 buckets)
        assert!(range_selectivity > 0.4 && range_selectivity < 0.6);
    }

    #[test]
    fn test_access_path_selection() {
        let optimizer = CostBasedOptimizer::new(CostModelConfig::default());
        optimizer.update_stats(create_test_stats());

        // High selectivity predicate should use index
        let by_id = Predicate::Eq {
            column: String::from("id"),
            value: String::from("12345"),
        };
        let projection = vec![String::from("id"), String::from("score")];
        let plan = optimizer.optimize("users", projection, Some(by_id), vec![], None);

        if let PhysicalPlan::IndexSeek { index, .. } = plan {
            assert_eq!(index, "pk_users");
        } else {
            panic!("Expected IndexSeek for equality on primary key");
        }
    }

    #[test]
    fn test_token_budget_limit() {
        let optimizer =
            CostBasedOptimizer::new(CostModelConfig::default()).with_token_budget(2048, 25.0);

        // With 2048 token budget and 25 tokens/row:
        // max_rows = (2048 - 50) / 25 = ~80
        let plan = optimizer.optimize("users", vec![String::from("id")], None, vec![], None);

        if let PhysicalPlan::Limit { limit, .. } = plan {
            assert!(limit <= 80);
        } else {
            panic!("Expected Limit to be injected");
        }
    }

    #[test]
    fn test_explain_output() {
        let optimizer = CostBasedOptimizer::new(CostModelConfig::default());
        optimizer.update_stats(create_test_stats());

        let projection = vec![String::from("id"), String::from("score")];
        let high_scores = Predicate::Gt {
            column: String::from("score"),
            value: String::from("80"),
        };
        let ordering = vec![(String::from("score"), SortDirection::Descending)];

        let plan = optimizer.optimize("users", projection, Some(high_scores), ordering, Some(10));

        // Both the explicit LIMIT and the ORDER BY must show up in the plan text.
        let rendered = optimizer.explain(&plan);
        assert!(rendered.contains("Limit"));
        assert!(rendered.contains("Sort"));
    }
}