1use parking_lot::RwLock;
47use std::collections::{HashMap, HashSet};
48use std::sync::Arc;
49use std::time::{SystemTime, UNIX_EPOCH};
50
/// Weights that drive the optimizer's cost formulas.
///
/// Plan costs derived from these weights are printed with an `ms` suffix by
/// `CostBasedOptimizer::explain`.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Cost per block read sequentially (table-scan blocks, index leaf pages).
    pub c_seq: f64,
    /// Cost per random access (one per index level; also scales row fetches).
    pub c_random: f64,
    /// CPU cost per row examined by a scan or a join estimate.
    pub c_filter: f64,
    /// CPU cost per comparison, used when estimating a sort.
    pub c_compare: f64,
    /// Divisor that converts a table's byte size into a block count.
    pub block_size: usize,
    /// NOTE(review): not read by any cost formula in this part of the file.
    pub btree_fanout: usize,
    /// NOTE(review): not read by any cost formula in this part of the file;
    /// units unconfirmed.
    pub memory_bandwidth: f64,
}

impl Default for CostModelConfig {
    /// Returns the built-in weights: random access is 50x the sequential
    /// block cost, and per-row CPU work is orders of magnitude cheaper than I/O.
    fn default() -> Self {
        CostModelConfig {
            c_seq: 0.1,
            c_random: 5.0,
            c_filter: 0.001,
            c_compare: 0.0001,
            block_size: 4096,
            btree_fanout: 100,
            memory_bandwidth: 10000.0,
        }
    }
}
87
88#[derive(Debug, Clone)]
94pub struct TableStats {
95 pub name: String,
97 pub row_count: u64,
99 pub size_bytes: u64,
101 pub column_stats: HashMap<String, ColumnStats>,
103 pub indices: Vec<IndexStats>,
105 pub last_updated: u64,
107}
108
109#[derive(Debug, Clone)]
111pub struct ColumnStats {
112 pub name: String,
114 pub distinct_count: u64,
116 pub null_count: u64,
118 pub min_value: Option<String>,
120 pub max_value: Option<String>,
122 pub avg_length: f64,
124 pub mcv: Vec<(String, f64)>,
126 pub histogram: Option<Histogram>,
128}
129
/// Bucketed distribution of a numeric column.
///
/// Bucket `i` holds `counts[i]` rows and spans `boundaries[i - 1]` to
/// `boundaries[i]`; the first bucket is unbounded below. A bucket that has no
/// boundary of its own (index `>= boundaries.len()`) is treated as the
/// open-ended tail above the last boundary.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Upper bound of each bucket, in ascending bucket order.
    pub boundaries: Vec<f64>,
    /// Row count of each bucket.
    pub counts: Vec<u64>,
    /// Total number of rows the histogram was built from.
    pub total_rows: u64,
}

impl Histogram {
    /// Estimates the fraction of rows whose value lies in `[min, max]`.
    ///
    /// `None` means "unbounded" on that side. The estimate is coarse: every
    /// bucket that touches the range contributes its full row count (no
    /// interpolation inside a bucket).
    ///
    /// Returns `0.5` when the histogram is empty (`total_rows == 0`). The
    /// result never exceeds `1.0`, and malformed histograms (more counts than
    /// boundaries) are tolerated instead of panicking on an out-of-range index.
    pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
        if self.total_rows == 0 {
            return 0.5;
        }

        let last_boundary = self.boundaries.last().copied();

        let selected_rows = self
            .counts
            .iter()
            .enumerate()
            .filter(|&(i, _)| {
                let bucket_min = if i == 0 {
                    f64::NEG_INFINITY
                } else {
                    // Buckets past the last boundary start at the last boundary.
                    self.boundaries
                        .get(i - 1)
                        .copied()
                        .or(last_boundary)
                        .unwrap_or(f64::NEG_INFINITY)
                };
                let bucket_max = self.boundaries.get(i).copied().unwrap_or(f64::INFINITY);

                min.map_or(true, |lo| bucket_max >= lo) && max.map_or(true, |hi| bucket_min <= hi)
            })
            .fold(0u64, |acc, (_, &count)| acc.saturating_add(count));

        // Inconsistent stats (counts summing past total_rows) must not yield > 100%.
        (selected_rows as f64 / self.total_rows as f64).min(1.0)
    }
}
177
178#[derive(Debug, Clone)]
180pub struct IndexStats {
181 pub name: String,
183 pub columns: Vec<String>,
185 pub is_primary: bool,
187 pub is_unique: bool,
189 pub index_type: IndexType,
191 pub leaf_pages: u64,
193 pub height: u32,
195 pub avg_leaf_density: f64,
197}
198
/// Kind of index structure recorded in `IndexStats::index_type`.
///
/// The cost formulas in this part of the file do not branch on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// B-tree index.
    BTree,
    /// Hash index.
    Hash,
    /// Log-structured merge index.
    LSM,
    /// Learned index.
    Learned,
    /// Vector index.
    Vector,
    /// Bloom-filter index.
    Bloom,
}
209
/// Filter expression over the columns of a single table.
///
/// Literals are carried as text; the optimizer parses them as `f64` when it
/// needs a numeric bound.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// `column = value`
    Eq { column: String, value: String },
    /// `column != value`
    Ne { column: String, value: String },
    /// `column < value`
    Lt { column: String, value: String },
    /// `column <= value`
    Le { column: String, value: String },
    /// `column > value`
    Gt { column: String, value: String },
    /// `column >= value`
    Ge { column: String, value: String },
    /// `min <= column <= max` (the derived key range is inclusive on both ends).
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// `column IN (values...)`
    In { column: String, values: Vec<String> },
    /// `column LIKE pattern`
    Like { column: String, pattern: String },
    /// `column IS NULL`
    IsNull { column: String },
    /// `column IS NOT NULL`
    IsNotNull { column: String },
    /// Conjunction of two predicates.
    And(Box<Predicate>, Box<Predicate>),
    /// Disjunction of two predicates.
    Or(Box<Predicate>, Box<Predicate>),
    /// Negation of a predicate.
    Not(Box<Predicate>),
}

impl Predicate {
    /// Returns the name of every column mentioned anywhere in the predicate tree.
    pub fn referenced_columns(&self) -> HashSet<String> {
        let mut found = HashSet::new();
        self.collect_columns(&mut found);
        found
    }

    /// Adds the columns of `self` and all nested predicates to `cols`.
    ///
    /// Walks the tree with an explicit work list instead of recursion.
    fn collect_columns(&self, cols: &mut HashSet<String>) {
        let mut pending: Vec<&Predicate> = vec![self];

        while let Some(node) = pending.pop() {
            match node {
                Self::And(left, right) | Self::Or(left, right) => {
                    pending.push(&**left);
                    pending.push(&**right);
                }
                Self::Not(inner) => pending.push(&**inner),
                Self::Eq { column, .. }
                | Self::Ne { column, .. }
                | Self::Lt { column, .. }
                | Self::Le { column, .. }
                | Self::Gt { column, .. }
                | Self::Ge { column, .. }
                | Self::Between { column, .. }
                | Self::In { column, .. }
                | Self::Like { column, .. }
                | Self::IsNull { column }
                | Self::IsNotNull { column } => {
                    cols.insert(column.clone());
                }
            }
        }
    }
}
282
283#[derive(Debug, Clone)]
289pub enum PhysicalPlan {
290 TableScan {
292 table: String,
293 columns: Vec<String>,
294 predicate: Option<Box<Predicate>>,
295 estimated_rows: u64,
296 estimated_cost: f64,
297 },
298 IndexSeek {
300 table: String,
301 index: String,
302 columns: Vec<String>,
303 key_range: KeyRange,
304 predicate: Option<Box<Predicate>>,
305 estimated_rows: u64,
306 estimated_cost: f64,
307 },
308 Filter {
310 input: Box<PhysicalPlan>,
311 predicate: Predicate,
312 estimated_rows: u64,
313 estimated_cost: f64,
314 },
315 Project {
317 input: Box<PhysicalPlan>,
318 columns: Vec<String>,
319 estimated_cost: f64,
320 },
321 Sort {
323 input: Box<PhysicalPlan>,
324 order_by: Vec<(String, SortDirection)>,
325 estimated_cost: f64,
326 },
327 Limit {
329 input: Box<PhysicalPlan>,
330 limit: u64,
331 offset: u64,
332 estimated_cost: f64,
333 },
334 NestedLoopJoin {
336 outer: Box<PhysicalPlan>,
337 inner: Box<PhysicalPlan>,
338 condition: Predicate,
339 join_type: JoinType,
340 estimated_rows: u64,
341 estimated_cost: f64,
342 },
343 HashJoin {
345 build: Box<PhysicalPlan>,
346 probe: Box<PhysicalPlan>,
347 build_keys: Vec<String>,
348 probe_keys: Vec<String>,
349 join_type: JoinType,
350 estimated_rows: u64,
351 estimated_cost: f64,
352 },
353 MergeJoin {
355 left: Box<PhysicalPlan>,
356 right: Box<PhysicalPlan>,
357 left_keys: Vec<String>,
358 right_keys: Vec<String>,
359 join_type: JoinType,
360 estimated_rows: u64,
361 estimated_cost: f64,
362 },
363 Aggregate {
365 input: Box<PhysicalPlan>,
366 group_by: Vec<String>,
367 aggregates: Vec<AggregateExpr>,
368 estimated_rows: u64,
369 estimated_cost: f64,
370 },
371}
372
/// Range of index keys to visit, expressed as raw key bytes.
///
/// A bound of `None` means the range is open on that side.
#[derive(Debug, Clone)]
pub struct KeyRange {
    /// Lower bound, or `None` for "from the first key".
    pub start: Option<Vec<u8>>,
    /// Upper bound, or `None` for "to the last key".
    pub end: Option<Vec<u8>>,
    /// Whether `start` itself belongs to the range.
    pub start_inclusive: bool,
    /// Whether `end` itself belongs to the range.
    pub end_inclusive: bool,
}

impl KeyRange {
    /// Range covering every key (no bounds, both ends flagged inclusive).
    pub fn all() -> Self {
        Self::range(None, None, true)
    }

    /// Range containing exactly `key`.
    pub fn point(key: Vec<u8>) -> Self {
        let upper = key.clone();
        KeyRange {
            start: Some(key),
            end: Some(upper),
            start_inclusive: true,
            end_inclusive: true,
        }
    }

    /// Range between `start` and `end`; `inclusive` applies to both ends.
    pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
        KeyRange {
            start,
            end,
            start_inclusive: inclusive,
            end_inclusive: inclusive,
        }
    }
}
410
/// Direction of one `ORDER BY` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Ascending order.
    Ascending,
    /// Descending order.
    Descending,
}
417
/// Join semantics attached to a join node of a `PhysicalPlan`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full join.
    Full,
    /// Cross join.
    Cross,
}
427
428#[derive(Debug, Clone)]
430pub struct AggregateExpr {
431 pub function: AggregateFunction,
432 pub column: Option<String>,
433 pub alias: String,
434}
435
/// Aggregate functions usable in an `AggregateExpr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
    /// `COUNT(DISTINCT ...)`
    CountDistinct,
}
446
447pub struct CostBasedOptimizer {
453 config: CostModelConfig,
455 stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
457 token_budget: Option<u64>,
459 tokens_per_row: f64,
461 plan_cache: Arc<RwLock<HashMap<u64, (PhysicalPlan, u64)>>>,
463 plan_cache_ttl_us: u64,
465}
466
467impl CostBasedOptimizer {
468 pub fn new(config: CostModelConfig) -> Self {
469 Self {
470 config,
471 stats_cache: Arc::new(RwLock::new(HashMap::new())),
472 token_budget: None,
473 tokens_per_row: 25.0, plan_cache: Arc::new(RwLock::new(HashMap::new())),
475 plan_cache_ttl_us: 5_000_000, }
477 }
478
479 pub fn with_plan_cache_ttl_ms(mut self, ttl_ms: u64) -> Self {
481 self.plan_cache_ttl_us = ttl_ms * 1000;
482 self
483 }
484
485 pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
487 self.token_budget = Some(budget);
488 self.tokens_per_row = tokens_per_row;
489 self
490 }
491
492 pub fn update_stats(&self, stats: TableStats) {
494 self.stats_cache.write().insert(stats.name.clone(), stats);
495 }
496
497 pub fn get_stats(&self, table: &str) -> Option<TableStats> {
499 self.stats_cache.read().get(table).cloned()
500 }
501
502 pub fn optimize(
504 &self,
505 table: &str,
506 columns: Vec<String>,
507 predicate: Option<Predicate>,
508 order_by: Vec<(String, SortDirection)>,
509 limit: Option<u64>,
510 ) -> PhysicalPlan {
511 let stats = self.get_stats(table);
512
513 let effective_limit = self.calculate_token_limit(limit);
515
516 let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
518
519 plan = self.apply_projection_pushdown(plan, columns.clone());
521
522 if !order_by.is_empty() {
524 plan = self.add_sort(plan, order_by, &stats);
525 }
526
527 if let Some(lim) = effective_limit {
529 plan = PhysicalPlan::Limit {
530 estimated_cost: 0.0,
531 input: Box::new(plan),
532 limit: lim,
533 offset: 0,
534 };
535 }
536
537 plan
538 }
539
540 fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
542 match (self.token_budget, user_limit) {
543 (Some(budget), Some(limit)) => {
544 let header_tokens = 50u64;
545 let usable = budget.saturating_sub(header_tokens);
546 let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
547 Some(limit.min(max_rows))
548 }
549 (Some(budget), None) => {
550 let header_tokens = 50u64;
551 let usable = budget.saturating_sub(header_tokens);
552 let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
553 Some(max_rows)
554 }
555 (None, limit) => limit,
556 }
557 }
558
559 fn choose_access_path(
561 &self,
562 table: &str,
563 columns: &[String],
564 predicate: Option<&Predicate>,
565 stats: &Option<TableStats>,
566 ) -> PhysicalPlan {
567 let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
568 let size_bytes = stats
569 .as_ref()
570 .map(|s| s.size_bytes)
571 .unwrap_or(row_count * 100);
572
573 let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
575
576 let mut best_index_cost = f64::MAX;
578 let mut best_index: Option<&IndexStats> = None;
579
580 if let Some(table_stats) = stats.as_ref()
581 && let Some(pred) = predicate
582 {
583 let pred_columns = pred.referenced_columns();
584
585 for index in &table_stats.indices {
586 if self.index_covers_predicate(index, &pred_columns) {
587 let selectivity = self.estimate_selectivity(pred, table_stats);
588 let index_cost = self.estimate_index_cost(index, row_count, selectivity);
589
590 if index_cost < best_index_cost {
591 best_index_cost = index_cost;
592 best_index = Some(index);
593 }
594 }
595 }
596 }
597
598 if best_index_cost < scan_cost {
600 let index = best_index.unwrap();
601 let selectivity = predicate
602 .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
603 .unwrap_or(1.0);
604
605 PhysicalPlan::IndexSeek {
606 table: table.to_string(),
607 index: index.name.clone(),
608 columns: columns.to_vec(),
609 key_range: predicate
610 .map(|p| Self::derive_key_range(p))
611 .unwrap_or_else(KeyRange::all),
612 predicate: predicate.map(|p| Box::new(p.clone())),
613 estimated_rows: (row_count as f64 * selectivity).max(1.0) as u64,
614 estimated_cost: best_index_cost,
615 }
616 } else {
617 PhysicalPlan::TableScan {
618 table: table.to_string(),
619 columns: columns.to_vec(),
620 predicate: predicate.map(|p| Box::new(p.clone())),
621 estimated_rows: row_count,
622 estimated_cost: scan_cost,
623 }
624 }
625 }
626
627 fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
629 if let Some(first_col) = index.columns.first() {
631 pred_columns.contains(first_col)
632 } else {
633 false
634 }
635 }
636
637 fn estimate_scan_cost(
642 &self,
643 row_count: u64,
644 size_bytes: u64,
645 _predicate: Option<&Predicate>,
646 ) -> f64 {
647 let blocks = (size_bytes as f64 / self.config.block_size as f64)
648 .ceil()
649 .max(1.0) as u64;
650
651 let io_cost = blocks as f64 * self.config.c_seq;
653
654 let cpu_cost = row_count as f64 * self.config.c_filter;
656
657 io_cost + cpu_cost
658 }
659
660 fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
664 let tree_cost = index.height as f64 * self.config.c_random;
666
667 let matching_rows = (total_rows as f64 * selectivity) as u64;
669 let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
670 let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
671
672 let fetch_cost = if index.is_primary {
674 0.0 } else {
676 matching_rows.min(1000) as f64 * self.config.c_random * 0.1 };
678
679 tree_cost + leaf_cost + fetch_cost
680 }
681
682 #[allow(clippy::only_used_in_recursion)]
684 fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
685 match predicate {
686 Predicate::Eq { column, value } => {
687 if let Some(col_stats) = stats.column_stats.get(column) {
688 for (mcv_val, freq) in &col_stats.mcv {
690 if mcv_val == value {
691 return *freq;
692 }
693 }
694 1.0 / col_stats.distinct_count.max(1) as f64
696 } else {
697 0.1 }
699 }
700 Predicate::Ne { .. } => 0.9, Predicate::Lt { column, value }
702 | Predicate::Le { column, value }
703 | Predicate::Gt { column, value }
704 | Predicate::Ge { column, value } => {
705 if let Some(col_stats) = stats.column_stats.get(column) {
706 if let Some(ref hist) = col_stats.histogram {
707 let val: f64 = value.parse().unwrap_or(0.0);
708 match predicate {
709 Predicate::Lt { .. } | Predicate::Le { .. } => {
710 hist.estimate_range_selectivity(None, Some(val))
711 }
712 _ => hist.estimate_range_selectivity(Some(val), None),
713 }
714 } else {
715 0.25 }
717 } else {
718 0.25
719 }
720 }
721 Predicate::Between { column, min, max } => {
722 if let Some(col_stats) = stats.column_stats.get(column) {
723 if let Some(ref hist) = col_stats.histogram {
724 let min_val: f64 = min.parse().unwrap_or(0.0);
725 let max_val: f64 = max.parse().unwrap_or(f64::MAX);
726 hist.estimate_range_selectivity(Some(min_val), Some(max_val))
727 } else {
728 0.2
729 }
730 } else {
731 0.2
732 }
733 }
734 Predicate::In { column, values } => {
735 if let Some(col_stats) = stats.column_stats.get(column) {
736 (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
737 } else {
738 (values.len() as f64 * 0.1).min(0.5)
739 }
740 }
741 Predicate::Like { .. } => 0.15, Predicate::IsNull { column } => {
743 if let Some(col_stats) = stats.column_stats.get(column) {
744 col_stats.null_count as f64 / stats.row_count.max(1) as f64
745 } else {
746 0.01
747 }
748 }
749 Predicate::IsNotNull { column } => {
750 if let Some(col_stats) = stats.column_stats.get(column) {
751 1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
752 } else {
753 0.99
754 }
755 }
756 Predicate::And(left, right) => {
757 self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
759 }
760 Predicate::Or(left, right) => {
761 let s1 = self.estimate_selectivity(left, stats);
762 let s2 = self.estimate_selectivity(right, stats);
763 (s1 + s2 - s1 * s2).min(1.0)
765 }
766 Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
767 }
768 }
769
770 fn derive_key_range(predicate: &Predicate) -> KeyRange {
772 match predicate {
773 Predicate::Eq { value, .. } => KeyRange::point(value.as_bytes().to_vec()),
774 Predicate::Lt { value, .. } | Predicate::Le { value, .. } => KeyRange::range(
775 None,
776 Some(value.as_bytes().to_vec()),
777 matches!(predicate, Predicate::Le { .. }),
778 ),
779 Predicate::Gt { value, .. } | Predicate::Ge { value, .. } => KeyRange::range(
780 Some(value.as_bytes().to_vec()),
781 None,
782 matches!(predicate, Predicate::Ge { .. }),
783 ),
784 Predicate::Between { min, max, .. } => KeyRange {
785 start: Some(min.as_bytes().to_vec()),
786 end: Some(max.as_bytes().to_vec()),
787 start_inclusive: true,
788 end_inclusive: true,
789 },
790 Predicate::And(left, _) => Self::derive_key_range(left),
791 _ => KeyRange::all(),
792 }
793 }
794
795 fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
799 match plan {
800 PhysicalPlan::TableScan {
801 ref table,
802 predicate,
803 estimated_rows,
804 estimated_cost,
805 columns: ref all_columns,
806 ..
807 } => {
808 let col_ratio = if all_columns.is_empty() || columns.is_empty() {
810 1.0
811 } else {
812 (columns.len() as f64 / all_columns.len().max(1) as f64).clamp(0.1, 1.0)
813 };
814 PhysicalPlan::TableScan {
815 table: table.clone(),
816 columns,
817 predicate,
818 estimated_rows,
819 estimated_cost: estimated_cost * col_ratio,
820 }
821 }
822 PhysicalPlan::IndexSeek {
823 table,
824 index,
825 key_range,
826 predicate,
827 estimated_rows,
828 estimated_cost,
829 ..
830 } => {
831 PhysicalPlan::IndexSeek {
832 table,
833 index,
834 columns, key_range,
836 predicate,
837 estimated_rows,
838 estimated_cost,
839 }
840 }
841 other => PhysicalPlan::Project {
842 input: Box::new(other),
843 columns,
844 estimated_cost: 0.0,
845 },
846 }
847 }
848
849 fn add_sort(
851 &self,
852 plan: PhysicalPlan,
853 order_by: Vec<(String, SortDirection)>,
854 _stats: &Option<TableStats>,
855 ) -> PhysicalPlan {
856 let estimated_rows = self.get_plan_rows(&plan);
857 let sort_cost = if estimated_rows > 0 {
858 estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
859 } else {
860 0.0
861 };
862
863 PhysicalPlan::Sort {
864 input: Box::new(plan),
865 order_by,
866 estimated_cost: sort_cost,
867 }
868 }
869
870 #[allow(clippy::only_used_in_recursion)]
872 fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
873 match plan {
874 PhysicalPlan::TableScan { estimated_rows, .. }
875 | PhysicalPlan::IndexSeek { estimated_rows, .. }
876 | PhysicalPlan::Filter { estimated_rows, .. }
877 | PhysicalPlan::Aggregate { estimated_rows, .. }
878 | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
879 | PhysicalPlan::HashJoin { estimated_rows, .. }
880 | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
881 PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
882 self.get_plan_rows(input)
883 }
884 PhysicalPlan::Limit { limit, .. } => *limit,
885 }
886 }
887
888 #[allow(clippy::only_used_in_recursion)]
890 pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
891 match plan {
892 PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
893 PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
894 PhysicalPlan::Filter {
895 estimated_cost,
896 input,
897 ..
898 } => *estimated_cost + self.get_plan_cost(input),
899 PhysicalPlan::Project {
900 estimated_cost,
901 input,
902 ..
903 } => *estimated_cost + self.get_plan_cost(input),
904 PhysicalPlan::Sort {
905 estimated_cost,
906 input,
907 ..
908 } => *estimated_cost + self.get_plan_cost(input),
909 PhysicalPlan::Limit {
910 estimated_cost,
911 input,
912 ..
913 } => *estimated_cost + self.get_plan_cost(input),
914 PhysicalPlan::NestedLoopJoin {
915 estimated_cost,
916 outer,
917 inner,
918 ..
919 } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
920 PhysicalPlan::HashJoin {
921 estimated_cost,
922 build,
923 probe,
924 ..
925 } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
926 PhysicalPlan::MergeJoin {
927 estimated_cost,
928 left,
929 right,
930 ..
931 } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
932 PhysicalPlan::Aggregate {
933 estimated_cost,
934 input,
935 ..
936 } => *estimated_cost + self.get_plan_cost(input),
937 }
938 }
939
940 pub fn explain(&self, plan: &PhysicalPlan) -> String {
942 self.explain_impl(plan, 0)
943 }
944
945 fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
946 let prefix = " ".repeat(indent);
947 let cost = self.get_plan_cost(plan);
948
949 match plan {
950 PhysicalPlan::TableScan {
951 table,
952 columns,
953 estimated_rows,
954 ..
955 } => {
956 format!(
957 "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
958 prefix, table, columns, estimated_rows, cost
959 )
960 }
961 PhysicalPlan::IndexSeek {
962 table,
963 index,
964 columns,
965 estimated_rows,
966 ..
967 } => {
968 format!(
969 "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
970 prefix, table, index, columns, estimated_rows, cost
971 )
972 }
973 PhysicalPlan::Filter {
974 input,
975 estimated_rows,
976 ..
977 } => {
978 format!(
979 "{}Filter [rows={}, cost={:.2}ms]\n{}",
980 prefix,
981 estimated_rows,
982 cost,
983 self.explain_impl(input, indent + 1)
984 )
985 }
986 PhysicalPlan::Project { input, columns, .. } => {
987 format!(
988 "{}Project [columns={:?}, cost={:.2}ms]\n{}",
989 prefix,
990 columns,
991 cost,
992 self.explain_impl(input, indent + 1)
993 )
994 }
995 PhysicalPlan::Sort {
996 input, order_by, ..
997 } => {
998 let order: Vec<_> = order_by
999 .iter()
1000 .map(|(c, d)| format!("{} {:?}", c, d))
1001 .collect();
1002 format!(
1003 "{}Sort [order={:?}, cost={:.2}ms]\n{}",
1004 prefix,
1005 order,
1006 cost,
1007 self.explain_impl(input, indent + 1)
1008 )
1009 }
1010 PhysicalPlan::Limit {
1011 input,
1012 limit,
1013 offset,
1014 ..
1015 } => {
1016 format!(
1017 "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
1018 prefix,
1019 limit,
1020 offset,
1021 cost,
1022 self.explain_impl(input, indent + 1)
1023 )
1024 }
1025 PhysicalPlan::HashJoin {
1026 build,
1027 probe,
1028 join_type,
1029 estimated_rows,
1030 ..
1031 } => {
1032 format!(
1033 "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1034 prefix,
1035 join_type,
1036 estimated_rows,
1037 cost,
1038 self.explain_impl(build, indent + 1),
1039 self.explain_impl(probe, indent + 1)
1040 )
1041 }
1042 PhysicalPlan::MergeJoin {
1043 left,
1044 right,
1045 join_type,
1046 estimated_rows,
1047 ..
1048 } => {
1049 format!(
1050 "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1051 prefix,
1052 join_type,
1053 estimated_rows,
1054 cost,
1055 self.explain_impl(left, indent + 1),
1056 self.explain_impl(right, indent + 1)
1057 )
1058 }
1059 PhysicalPlan::NestedLoopJoin {
1060 outer,
1061 inner,
1062 join_type,
1063 estimated_rows,
1064 ..
1065 } => {
1066 format!(
1067 "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1068 prefix,
1069 join_type,
1070 estimated_rows,
1071 cost,
1072 self.explain_impl(outer, indent + 1),
1073 self.explain_impl(inner, indent + 1)
1074 )
1075 }
1076 PhysicalPlan::Aggregate {
1077 input,
1078 group_by,
1079 aggregates,
1080 estimated_rows,
1081 ..
1082 } => {
1083 let aggs: Vec<_> = aggregates
1084 .iter()
1085 .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
1086 .collect();
1087 format!(
1088 "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
1089 prefix,
1090 group_by,
1091 aggs,
1092 estimated_rows,
1093 cost,
1094 self.explain_impl(input, indent + 1)
1095 )
1096 }
1097 }
1098 }
1099}
1100
1101impl CostBasedOptimizer {
1106 pub fn evict_stale_plans(&self) {
1108 let now = Self::now_us();
1109 self.plan_cache
1110 .write()
1111 .retain(|_, (_, ts)| now.saturating_sub(*ts) < self.plan_cache_ttl_us);
1112 }
1113
1114 pub fn invalidate_plan_cache(&self) {
1116 self.plan_cache.write().clear();
1117 }
1118
1119 pub fn collect_stats(
1124 &self,
1125 table_name: &str,
1126 row_count: u64,
1127 size_bytes: u64,
1128 column_values: HashMap<String, Vec<String>>,
1129 indices: Vec<IndexStats>,
1130 ) {
1131 let mut column_stats = HashMap::new();
1132 for (col_name, values) in &column_values {
1133 let distinct: HashSet<&String> = values.iter().collect();
1134 let null_count = values.iter().filter(|v| v.is_empty()).count() as u64;
1135 let avg_length = if values.is_empty() {
1136 0.0
1137 } else {
1138 values.iter().map(|v| v.len()).sum::<usize>() as f64 / values.len() as f64
1139 };
1140
1141 let is_numeric = values.iter().take(10).all(|v| v.parse::<f64>().is_ok());
1143 let histogram = if is_numeric && values.len() >= 10 {
1144 let mut nums: Vec<f64> = values.iter().filter_map(|v| v.parse().ok()).collect();
1145 nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1146 let bucket_count = 10.min(nums.len());
1147 let bucket_size = nums.len() / bucket_count;
1148 let mut boundaries = Vec::new();
1149 let mut counts = Vec::new();
1150 for i in 0..bucket_count {
1151 let end = if i == bucket_count - 1 {
1152 nums.len()
1153 } else {
1154 (i + 1) * bucket_size
1155 };
1156 let start = i * bucket_size;
1157 boundaries.push(nums[end - 1]);
1158 counts.push((end - start) as u64);
1159 }
1160 Some(Histogram {
1161 boundaries,
1162 counts,
1163 total_rows: nums.len() as u64,
1164 })
1165 } else {
1166 None
1167 };
1168
1169 let mut freq_map: HashMap<&String, usize> = HashMap::new();
1171 for v in values {
1172 *freq_map.entry(v).or_insert(0) += 1;
1173 }
1174 let total = values.len() as f64;
1175 let mut mcv: Vec<(String, f64)> = freq_map
1176 .iter()
1177 .map(|(k, &v)| ((*k).clone(), v as f64 / total))
1178 .collect();
1179 mcv.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1180 mcv.truncate(5);
1181
1182 column_stats.insert(
1183 col_name.clone(),
1184 ColumnStats {
1185 name: col_name.clone(),
1186 distinct_count: distinct.len() as u64,
1187 null_count,
1188 min_value: values.iter().min().cloned(),
1189 max_value: values.iter().max().cloned(),
1190 avg_length,
1191 mcv,
1192 histogram,
1193 },
1194 );
1195 }
1196
1197 self.update_stats(TableStats {
1198 name: table_name.to_string(),
1199 row_count,
1200 size_bytes,
1201 column_stats,
1202 indices,
1203 last_updated: Self::now_us(),
1204 });
1205
1206 self.invalidate_plan_cache();
1208 }
1209
1210 pub fn stats_age_us(&self, table: &str) -> Option<u64> {
1212 self.stats_cache
1213 .read()
1214 .get(table)
1215 .map(|s| Self::now_us().saturating_sub(s.last_updated))
1216 }
1217
/// Current wall-clock time in microseconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_us() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_micros() as u64,
        Err(_) => 0,
    }
}
1224}
1225
1226pub struct JoinOrderOptimizer {
1232 stats: HashMap<String, TableStats>,
1234 config: CostModelConfig,
1236}
1237
1238impl JoinOrderOptimizer {
1239 pub fn new(config: CostModelConfig) -> Self {
1240 Self {
1241 stats: HashMap::new(),
1242 config,
1243 }
1244 }
1245
1246 pub fn add_stats(&mut self, stats: TableStats) {
1248 self.stats.insert(stats.name.clone(), stats);
1249 }
1250
1251 pub fn find_optimal_order(
1256 &self,
1257 tables: &[String],
1258 join_conditions: &[(String, String, String, String)], ) -> Vec<(String, String)> {
1260 let n = tables.len();
1261 if n <= 1 {
1262 return vec![];
1263 }
1264
1265 let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1267
1268 for (i, _table) in tables.iter().enumerate() {
1270 let mask = 1u32 << i;
1271 dp.insert(mask, (0.0, vec![]));
1272 }
1273
1274 for size in 2..=n {
1276 for mask in 0..(1u32 << n) {
1277 if mask.count_ones() != size as u32 {
1278 continue;
1279 }
1280
1281 let mut best_cost = f64::MAX;
1282 let mut best_order = vec![];
1283
1284 for sub in 1..mask {
1286 if sub & mask != sub || sub == 0 {
1287 continue;
1288 }
1289 let other = mask ^ sub;
1290 if other == 0 {
1291 continue;
1292 }
1293
1294 if !self.has_join_condition(tables, sub, other, join_conditions) {
1296 continue;
1297 }
1298
1299 if let (Some((cost1, order1)), Some((cost2, order2))) =
1300 (dp.get(&sub), dp.get(&other))
1301 {
1302 let join_cost = self.estimate_join_cost(tables, sub, other);
1303 let total_cost = cost1 + cost2 + join_cost;
1304
1305 if total_cost < best_cost {
1306 best_cost = total_cost;
1307 best_order = order1.clone();
1308 best_order.extend(order2.clone());
1309
1310 let (t1, t2) =
1312 self.get_join_tables(tables, sub, other, join_conditions);
1313 if let Some((t1, t2)) = Some((t1, t2)) {
1314 best_order.push((t1, t2));
1315 }
1316 }
1317 }
1318 }
1319
1320 if best_cost < f64::MAX {
1321 dp.insert(mask, (best_cost, best_order));
1322 }
1323 }
1324 }
1325
1326 let full_mask = (1u32 << n) - 1;
1327 dp.get(&full_mask)
1328 .map(|(_, order)| order.clone())
1329 .unwrap_or_default()
1330 }
1331
1332 fn has_join_condition(
1333 &self,
1334 tables: &[String],
1335 mask1: u32,
1336 mask2: u32,
1337 conditions: &[(String, String, String, String)],
1338 ) -> bool {
1339 for (t1, _, t2, _) in conditions {
1340 let idx1 = tables.iter().position(|t| t == t1);
1341 let idx2 = tables.iter().position(|t| t == t2);
1342
1343 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1344 let in_mask1 = (mask1 >> i1) & 1 == 1;
1345 let in_mask2 = (mask2 >> i2) & 1 == 1;
1346
1347 if in_mask1 && in_mask2 {
1348 return true;
1349 }
1350 }
1351 }
1352 false
1353 }
1354
1355 fn get_join_tables(
1356 &self,
1357 tables: &[String],
1358 mask1: u32,
1359 mask2: u32,
1360 conditions: &[(String, String, String, String)],
1361 ) -> (String, String) {
1362 for (t1, _, t2, _) in conditions {
1363 let idx1 = tables.iter().position(|t| t == t1);
1364 let idx2 = tables.iter().position(|t| t == t2);
1365
1366 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1367 let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1368 let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1369
1370 if t1_in_mask1 && t2_in_mask2 {
1371 return (t1.clone(), t2.clone());
1372 }
1373 }
1374 }
1375 (String::new(), String::new())
1376 }
1377
1378 fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1379 let rows1 = self.estimate_rows_for_mask(tables, mask1);
1380 let rows2 = self.estimate_rows_for_mask(tables, mask2);
1381
1382 let build_cost = rows1 as f64 * self.config.c_filter;
1385 let probe_cost = rows2 as f64 * self.config.c_filter;
1386
1387 build_cost + probe_cost
1388 }
1389
1390 fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1391 let mut total = 1u64;
1392
1393 for (i, table) in tables.iter().enumerate() {
1394 if (mask >> i) & 1 == 1 {
1395 let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1396 total = total.saturating_mul(rows);
1397 }
1398 }
1399
1400 let num_tables = mask.count_ones();
1402 if num_tables > 1 {
1403 total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1404 }
1405
1406 total.max(1)
1407 }
1408}
1409
1410#[cfg(test)]
1415mod tests {
1416 use super::*;
1417
1418 fn create_test_stats() -> TableStats {
1419 let mut column_stats = HashMap::new();
1420 column_stats.insert(
1421 "id".to_string(),
1422 ColumnStats {
1423 name: "id".to_string(),
1424 distinct_count: 100000,
1425 null_count: 0,
1426 min_value: Some("1".to_string()),
1427 max_value: Some("100000".to_string()),
1428 avg_length: 8.0,
1429 mcv: vec![],
1430 histogram: None,
1431 },
1432 );
1433 column_stats.insert(
1434 "score".to_string(),
1435 ColumnStats {
1436 name: "score".to_string(),
1437 distinct_count: 100,
1438 null_count: 1000,
1439 min_value: Some("0".to_string()),
1440 max_value: Some("100".to_string()),
1441 avg_length: 8.0,
1442 mcv: vec![("50".to_string(), 0.05)],
1443 histogram: Some(Histogram {
1444 boundaries: vec![25.0, 50.0, 75.0, 100.0],
1445 counts: vec![25000, 25000, 25000, 25000],
1446 total_rows: 100000,
1447 }),
1448 },
1449 );
1450
1451 TableStats {
1452 name: "users".to_string(),
1453 row_count: 100000,
1454 size_bytes: 10_000_000, column_stats,
1456 indices: vec![
1457 IndexStats {
1458 name: "pk_users".to_string(),
1459 columns: vec!["id".to_string()],
1460 is_primary: true,
1461 is_unique: true,
1462 index_type: IndexType::BTree,
1463 leaf_pages: 1000,
1464 height: 3,
1465 avg_leaf_density: 100.0,
1466 },
1467 IndexStats {
1468 name: "idx_score".to_string(),
1469 columns: vec!["score".to_string()],
1470 is_primary: false,
1471 is_unique: false,
1472 index_type: IndexType::BTree,
1473 leaf_pages: 500,
1474 height: 2,
1475 avg_leaf_density: 200.0,
1476 },
1477 ],
1478 last_updated: 0,
1479 }
1480 }
1481
/// Equality on a near-unique column is highly selective; a range over the
/// upper buckets of the `score` histogram selects about half the rows.
#[test]
fn test_selectivity_estimation() {
    let optimizer = CostBasedOptimizer::new(CostModelConfig::default());
    let stats = create_test_stats();
    optimizer.update_stats(stats.clone());

    // 100000 distinct ids and no MCV entry: 1 / 100000.
    let by_id = Predicate::Eq {
        column: "id".to_string(),
        value: "12345".to_string(),
    };
    let sel = optimizer.estimate_selectivity(&by_id, &stats);
    assert!(sel < 0.001);

    // Two of the four histogram buckets reach 75 or above.
    let high_scores = Predicate::Gt {
        column: "score".to_string(),
        value: "75".to_string(),
    };
    let sel = optimizer.estimate_selectivity(&high_scores, &stats);
    assert!(sel > 0.4 && sel < 0.6);
}
1508
1509 #[test]
1510 fn test_access_path_selection() {
1511 let config = CostModelConfig::default();
1512 let optimizer = CostBasedOptimizer::new(config);
1513
1514 let stats = create_test_stats();
1515 optimizer.update_stats(stats);
1516
1517 let pred = Predicate::Eq {
1519 column: "id".to_string(),
1520 value: "12345".to_string(),
1521 };
1522 let plan = optimizer.optimize(
1523 "users",
1524 vec!["id".to_string(), "score".to_string()],
1525 Some(pred),
1526 vec![],
1527 None,
1528 );
1529
1530 match plan {
1531 PhysicalPlan::IndexSeek { index, .. } => {
1532 assert_eq!(index, "pk_users");
1533 }
1534 _ => panic!("Expected IndexSeek for equality on primary key"),
1535 }
1536 }
1537
1538 #[test]
1539 fn test_token_budget_limit() {
1540 let config = CostModelConfig::default();
1541 let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);
1542
1543 let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1546
1547 match plan {
1548 PhysicalPlan::Limit { limit, .. } => {
1549 assert!(limit <= 80);
1550 }
1551 _ => panic!("Expected Limit to be injected"),
1552 }
1553 }
1554
1555 #[test]
1556 fn test_explain_output() {
1557 let config = CostModelConfig::default();
1558 let optimizer = CostBasedOptimizer::new(config);
1559
1560 let stats = create_test_stats();
1561 optimizer.update_stats(stats);
1562
1563 let plan = optimizer.optimize(
1564 "users",
1565 vec!["id".to_string(), "score".to_string()],
1566 Some(Predicate::Gt {
1567 column: "score".to_string(),
1568 value: "80".to_string(),
1569 }),
1570 vec![("score".to_string(), SortDirection::Descending)],
1571 Some(10),
1572 );
1573
1574 let explain = optimizer.explain(&plan);
1575 assert!(explain.contains("Limit"));
1576 assert!(explain.contains("Sort"));
1577 }
1578
1579 #[test]
1584 fn test_token_budget_underflow_safety() {
1585 let config = CostModelConfig::default();
1587 let optimizer = CostBasedOptimizer::new(config).with_token_budget(10, 25.0);
1588
1589 let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1590 match plan {
1591 PhysicalPlan::Limit { limit, .. } => {
1592 assert!(limit >= 1, "Must return at least 1 row");
1593 }
1594 _ => panic!("Expected Limit"),
1595 }
1596 }
1597
1598 #[test]
1599 fn test_index_seek_derives_key_range() {
1600 let config = CostModelConfig::default();
1601 let optimizer = CostBasedOptimizer::new(config);
1602 optimizer.update_stats(create_test_stats());
1603
1604 let plan = optimizer.optimize(
1605 "users",
1606 vec!["id".to_string()],
1607 Some(Predicate::Eq {
1608 column: "id".to_string(),
1609 value: "42".to_string(),
1610 }),
1611 vec![],
1612 None,
1613 );
1614
1615 match plan {
1616 PhysicalPlan::IndexSeek { key_range, .. } => {
1617 assert!(
1618 key_range.start.is_some(),
1619 "KeyRange must derive from Eq predicate"
1620 );
1621 assert_eq!(
1622 key_range.start, key_range.end,
1623 "Eq predicate → point key range"
1624 );
1625 }
1626 _ => panic!("Expected IndexSeek"),
1627 }
1628 }
1629
1630 #[test]
1631 fn test_range_predicate_key_range() {
1632 let config = CostModelConfig::default();
1633 let optimizer = CostBasedOptimizer::new(config);
1634 optimizer.update_stats(create_test_stats());
1635
1636 let plan = optimizer.optimize(
1637 "users",
1638 vec!["score".to_string()],
1639 Some(Predicate::Between {
1640 column: "score".to_string(),
1641 min: "10".to_string(),
1642 max: "90".to_string(),
1643 }),
1644 vec![],
1645 None,
1646 );
1647
1648 match plan {
1649 PhysicalPlan::IndexSeek { key_range, .. } => {
1650 assert!(key_range.start.is_some());
1651 assert!(key_range.end.is_some());
1652 assert!(key_range.start_inclusive);
1653 assert!(key_range.end_inclusive);
1654 }
1655 _ => {} }
1657 }
1658
1659 #[test]
1660 fn test_projection_pushdown_proportional_reduction() {
1661 let config = CostModelConfig::default();
1662 let optimizer = CostBasedOptimizer::new(config);
1663 optimizer.update_stats(create_test_stats());
1664
1665 let plan_all = optimizer.optimize(
1667 "users",
1668 vec!["id".to_string(), "score".to_string()],
1669 None,
1670 vec![],
1671 Some(100),
1672 );
1673 let plan_single =
1674 optimizer.optimize("users", vec!["id".to_string()], None, vec![], Some(100));
1675
1676 let cost_all = optimizer.get_plan_cost(&plan_all);
1677 let cost_single = optimizer.get_plan_cost(&plan_single);
1678 assert!(
1680 cost_single <= cost_all,
1681 "Projection should reduce cost: {} vs {}",
1682 cost_single,
1683 cost_all
1684 );
1685 }
1686
1687 #[test]
1688 fn test_collect_stats_builds_histogram() {
1689 let config = CostModelConfig::default();
1690 let optimizer = CostBasedOptimizer::new(config);
1691
1692 let mut column_values = HashMap::new();
1693 let scores: Vec<String> = (0..100).map(|i| i.to_string()).collect();
1694 column_values.insert("score".to_string(), scores);
1695
1696 optimizer.collect_stats("test_table", 100, 10000, column_values, vec![]);
1697
1698 let stats = optimizer.get_stats("test_table").unwrap();
1699 assert_eq!(stats.row_count, 100);
1700 let score_stats = stats.column_stats.get("score").unwrap();
1701 assert_eq!(score_stats.distinct_count, 100);
1702 assert!(
1703 score_stats.histogram.is_some(),
1704 "Numeric column should get histogram"
1705 );
1706 assert!(!score_stats.mcv.is_empty(), "Should build MCV list");
1707 }
1708
1709 #[test]
1710 fn test_plan_cache_invalidation() {
1711 let config = CostModelConfig::default();
1712 let optimizer = CostBasedOptimizer::new(config);
1713
1714 let mut col = HashMap::new();
1716 col.insert("x".to_string(), vec!["1".to_string()]);
1717 optimizer.collect_stats("t", 1, 100, col.clone(), vec![]);
1718
1719 assert!(optimizer.plan_cache.read().is_empty());
1721 }
1722
1723 #[test]
1724 fn test_stats_age_tracking() {
1725 let config = CostModelConfig::default();
1726 let optimizer = CostBasedOptimizer::new(config);
1727
1728 assert!(optimizer.stats_age_us("unknown").is_none());
1729
1730 let mut col = HashMap::new();
1731 col.insert("x".to_string(), vec!["1".to_string()]);
1732 optimizer.collect_stats("t", 1, 100, col, vec![]);
1733
1734 let age = optimizer.stats_age_us("t").unwrap();
1735 assert!(age < 1_000_000, "Stats should be fresh (< 1 second old)");
1736 }
1737
1738 #[test]
1739 fn test_scan_cost_reads_all_blocks() {
1740 let config = CostModelConfig::default();
1742 let optimizer = CostBasedOptimizer::new(config.clone());
1743 let no_pred = optimizer.estimate_scan_cost(1000, 4096 * 10, None);
1744 let with_pred = optimizer.estimate_scan_cost(
1745 1000,
1746 4096 * 10,
1747 Some(&Predicate::Eq {
1748 column: "x".to_string(),
1749 value: "1".to_string(),
1750 }),
1751 );
1752 assert!(
1755 (no_pred - with_pred).abs() < 0.001,
1756 "Scan cost should not depend on predicate: {} vs {}",
1757 no_pred,
1758 with_pred
1759 );
1760 }
1761
1762 #[test]
1763 fn test_index_wins_over_scan_for_point_lookup() {
1764 let config = CostModelConfig::default();
1765 let optimizer = CostBasedOptimizer::new(config);
1766 optimizer.update_stats(create_test_stats());
1767
1768 let scan_cost = optimizer.estimate_scan_cost(100000, 10_000_000, None);
1769
1770 let pk_index = &create_test_stats().indices[0]; let index_cost = optimizer.estimate_index_cost(pk_index, 100000, 0.00001);
1773
1774 assert!(
1775 index_cost < scan_cost * 0.1,
1776 "Index point lookup ({:.2}) should be <10% of scan cost ({:.2})",
1777 index_cost,
1778 scan_cost
1779 );
1780 }
1781
1782 #[test]
1783 fn test_no_stats_defaults_to_scan() {
1784 let config = CostModelConfig::default();
1785 let optimizer = CostBasedOptimizer::new(config);
1786 let plan = optimizer.optimize(
1788 "unknown_table",
1789 vec!["col1".to_string()],
1790 Some(Predicate::Eq {
1791 column: "col1".to_string(),
1792 value: "x".to_string(),
1793 }),
1794 vec![],
1795 None,
1796 );
1797 match plan {
1799 PhysicalPlan::TableScan { estimated_rows, .. } => {
1800 assert!(estimated_rows > 0, "Default row estimate must be positive");
1801 }
1802 PhysicalPlan::IndexSeek { .. } => {} _ => panic!("Expected TableScan or IndexSeek for unknown table"),
1804 }
1805 }
1806
1807 #[test]
1808 fn test_compound_predicate_selectivity() {
1809 let stats = create_test_stats();
1810 let config = CostModelConfig::default();
1811 let optimizer = CostBasedOptimizer::new(config);
1812
1813 let and_pred = Predicate::And(
1815 Box::new(Predicate::Eq {
1816 column: "id".to_string(),
1817 value: "1".to_string(),
1818 }),
1819 Box::new(Predicate::IsNotNull {
1820 column: "score".to_string(),
1821 }),
1822 );
1823 let sel = optimizer.estimate_selectivity(&and_pred, &stats);
1824 let eq_sel = optimizer.estimate_selectivity(
1825 &Predicate::Eq {
1826 column: "id".to_string(),
1827 value: "1".to_string(),
1828 },
1829 &stats,
1830 );
1831 assert!(sel < eq_sel, "AND must be more selective than either child");
1832
1833 let or_pred = Predicate::Or(
1835 Box::new(Predicate::Eq {
1836 column: "id".to_string(),
1837 value: "1".to_string(),
1838 }),
1839 Box::new(Predicate::Eq {
1840 column: "id".to_string(),
1841 value: "2".to_string(),
1842 }),
1843 );
1844 let sel = optimizer.estimate_selectivity(&or_pred, &stats);
1845 assert!(sel > eq_sel, "OR must be less selective than either child");
1846 assert!(sel <= 1.0, "Selectivity must be <= 1.0");
1847 }
1848
1849 #[test]
1850 fn test_join_order_optimizer() {
1851 let mut join_opt = JoinOrderOptimizer::new(CostModelConfig::default());
1852 join_opt.add_stats(TableStats {
1853 name: "orders".to_string(),
1854 row_count: 1000000,
1855 size_bytes: 100_000_000,
1856 column_stats: HashMap::new(),
1857 indices: vec![],
1858 last_updated: 0,
1859 });
1860 join_opt.add_stats(TableStats {
1861 name: "users".to_string(),
1862 row_count: 10000,
1863 size_bytes: 1_000_000,
1864 column_stats: HashMap::new(),
1865 indices: vec![],
1866 last_updated: 0,
1867 });
1868
1869 let order = join_opt.find_optimal_order(
1870 &["orders".to_string(), "users".to_string()],
1871 &[(
1872 "orders".to_string(),
1873 "user_id".to_string(),
1874 "users".to_string(),
1875 "id".to_string(),
1876 )],
1877 );
1878 assert!(!order.is_empty(), "Should find a join order");
1879 }
1880}