1use parking_lot::RwLock;
44use std::collections::{HashMap, HashSet};
45use std::sync::Arc;
46
/// Tunable weights and sizes consumed by the cost estimators in this file.
///
/// The weights are relative cost units; the estimators only ever add and
/// multiply them, so their absolute scale is up to the caller.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Cost charged per block read sequentially (table scans, index leaf pages).
    pub c_seq: f64,
    /// Cost charged per random access (index tree levels, non-primary row fetches).
    pub c_random: f64,
    /// Per-row cost of evaluating a filter; the join estimate also charges it per row.
    pub c_filter: f64,
    /// Per-comparison cost used by the sort estimate.
    pub c_compare: f64,
    /// Block size in bytes, used to turn a table's byte size into a block count.
    pub block_size: usize,
    /// B-tree fanout. Not read by the estimators visible in this file.
    pub btree_fanout: usize,
    /// Memory bandwidth figure. Not read by the estimators visible in this file.
    pub memory_bandwidth: f64,
}

impl Default for CostModelConfig {
    /// Returns the built-in baseline configuration.
    fn default() -> Self {
        CostModelConfig {
            // Sizes and hardware figures.
            block_size: 4096,
            btree_fanout: 100,
            memory_bandwidth: 10000.0,
            // I/O weights.
            c_seq: 0.1,
            c_random: 5.0,
            // CPU weights.
            c_filter: 0.001,
            c_compare: 0.0001,
        }
    }
}
83
/// Statistics describing one table, as consumed by the optimizers in this file.
#[derive(Debug, Clone)]
pub struct TableStats {
    /// Table name; the optimizers use it as the key when storing these statistics.
    pub name: String,
    /// Number of rows in the table.
    pub row_count: u64,
    /// Table size in bytes; divided by the configured block size to cost a scan.
    pub size_bytes: u64,
    /// Per-column statistics, looked up by the column name a predicate refers to.
    pub column_stats: HashMap<String, ColumnStats>,
    /// Indexes considered as alternatives to a full scan.
    pub indices: Vec<IndexStats>,
    /// Not read in this file. NOTE(review): presumably a timestamp of the last
    /// statistics refresh — unit and epoch are not stated here; confirm with the producer.
    pub last_updated: u64,
}
104
/// Statistics for a single column, used for selectivity estimation.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values; an equality match outside the MCV list is
    /// estimated as `1 / distinct_count`.
    pub distinct_count: u64,
    /// Number of NULLs; divided by the table row count for `IS NULL` estimates.
    pub null_count: u64,
    /// Smallest value, stored as text. Not read by the estimators in this file.
    pub min_value: Option<String>,
    /// Largest value, stored as text. Not read by the estimators in this file.
    pub max_value: Option<String>,
    /// Average value length. Not read in this file; unit presumably bytes — confirm.
    pub avg_length: f64,
    /// Most-common values as `(value, frequency)`; the frequency is returned
    /// directly as the selectivity of an equality match on that value.
    pub mcv: Vec<(String, f64)>,
    /// Optional histogram used for range predicates (`<`, `<=`, `>`, `>=`, `BETWEEN`).
    pub histogram: Option<Histogram>,
}
125
/// Bucketed value distribution of a numeric column.
///
/// Bucket `i` holds `counts[i]` rows. Its lower edge is `boundaries[i - 1]`
/// (negative infinity for the first bucket) and its upper edge is
/// `boundaries[i]` (positive infinity when `i == boundaries.len()`).
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Upper edges of the buckets, in bucket order.
    pub boundaries: Vec<f64>,
    /// Row count per bucket.
    pub counts: Vec<u64>,
    /// Total number of rows the histogram describes (the selectivity denominator).
    pub total_rows: u64,
}

impl Histogram {
    /// Estimates the fraction of rows whose value lies in `[min, max]`.
    ///
    /// `None` leaves that side of the range unbounded. The estimate is coarse:
    /// every bucket that touches the range is counted in full, edges included.
    ///
    /// Returns `0.5` when `total_rows` is zero. Otherwise the result is at
    /// most `1.0`, even if the bucket counts add up to more than `total_rows`.
    ///
    /// Only the first `boundaries.len() + 1` buckets have a defined range;
    /// surplus entries in `counts` are ignored rather than indexing past
    /// `boundaries`.
    pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
        if self.total_rows == 0 {
            return 0.5;
        }

        // Buckets beyond this index have neither a lower nor an upper edge.
        let usable_buckets = self.boundaries.len().saturating_add(1);

        let selected_rows = self
            .counts
            .iter()
            .take(usable_buckets)
            .enumerate()
            .filter(|&(i, _)| {
                let bucket_min = i
                    .checked_sub(1)
                    .and_then(|prev| self.boundaries.get(prev))
                    .copied()
                    .unwrap_or(f64::NEG_INFINITY);
                let bucket_max = self.boundaries.get(i).copied().unwrap_or(f64::INFINITY);

                let above_min = min.map_or(true, |lo| bucket_max >= lo);
                let below_max = max.map_or(true, |hi| bucket_min <= hi);
                above_min && below_max
            })
            // Saturate instead of overflowing on absurdly large counts.
            .fold(0u64, |acc, (_, &count)| acc.saturating_add(count));

        (selected_rows as f64 / self.total_rows as f64).min(1.0)
    }
}
173
/// Statistics describing one index of a table.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Index name, copied into the plan when the index is chosen.
    pub name: String,
    /// Indexed columns in key order; only the first is checked when deciding
    /// whether the index can serve a predicate.
    pub columns: Vec<String>,
    /// Primary indexes are charged no extra row-fetch cost by the index cost estimate.
    pub is_primary: bool,
    /// Uniqueness flag. Not read by the estimators in this file.
    pub is_unique: bool,
    /// Kind of index structure. Not read by the estimators in this file.
    pub index_type: IndexType,
    /// Number of leaf pages. Not read by the estimators in this file.
    pub leaf_pages: u64,
    /// Tree height; each level is charged one random access.
    pub height: u32,
    /// Divisor that converts matching rows into leaf pages scanned
    /// (presumably entries per leaf page — confirm with the statistics producer).
    pub avg_leaf_density: f64,
}
194
/// Kind of index structure recorded in [`IndexStats`].
///
/// The cost estimators visible in this file do not branch on this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    BTree,
    Hash,
    LSM,
    Learned,
    Vector,
    Bloom,
}
205
/// Filter expression over the columns of a single table.
///
/// Leaf variants name one column and carry their literal operands as text;
/// `And`, `Or` and `Not` combine other predicates.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// `column = value`
    Eq { column: String, value: String },
    /// `column != value`
    Ne { column: String, value: String },
    /// `column < value`
    Lt { column: String, value: String },
    /// `column <= value`
    Le { column: String, value: String },
    /// `column > value`
    Gt { column: String, value: String },
    /// `column >= value`
    Ge { column: String, value: String },
    /// `column` between `min` and `max`.
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// `column` equal to one of `values`.
    In { column: String, values: Vec<String> },
    /// `column` matching `pattern`.
    Like { column: String, pattern: String },
    /// `column IS NULL`
    IsNull { column: String },
    /// `column IS NOT NULL`
    IsNotNull { column: String },
    /// Both operands hold.
    And(Box<Predicate>, Box<Predicate>),
    /// At least one operand holds.
    Or(Box<Predicate>, Box<Predicate>),
    /// Negation of the operand.
    Not(Box<Predicate>),
}

impl Predicate {
    /// Returns the set of column names mentioned anywhere in this predicate.
    pub fn referenced_columns(&self) -> HashSet<String> {
        let mut found = HashSet::new();
        self.collect_columns(&mut found);
        found
    }

    /// Inserts every column mentioned by this predicate tree into `cols`.
    fn collect_columns(&self, cols: &mut HashSet<String>) {
        // Walk the tree with an explicit work list instead of recursing.
        let mut pending: Vec<&Predicate> = vec![self];

        while let Some(node) = pending.pop() {
            match node {
                Predicate::And(lhs, rhs) | Predicate::Or(lhs, rhs) => {
                    pending.push(lhs);
                    pending.push(rhs);
                }
                Predicate::Not(operand) => pending.push(operand),
                Predicate::Eq { column, .. }
                | Predicate::Ne { column, .. }
                | Predicate::Lt { column, .. }
                | Predicate::Le { column, .. }
                | Predicate::Gt { column, .. }
                | Predicate::Ge { column, .. }
                | Predicate::Between { column, .. }
                | Predicate::In { column, .. }
                | Predicate::Like { column, .. }
                | Predicate::IsNull { column }
                | Predicate::IsNotNull { column } => {
                    cols.insert(column.clone());
                }
            }
        }
    }
}
278
/// Physical query plan tree.
///
/// Each node stores `estimated_cost` for the node itself; the optimizer's
/// total-cost walk adds the costs of a node's inputs on top of it.
#[derive(Debug, Clone)]
pub enum PhysicalPlan {
    /// Scan of `table`, carrying an optional predicate.
    TableScan {
        table: String,
        columns: Vec<String>,
        predicate: Option<Box<Predicate>>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Access to `table` through `index` over `key_range`, carrying an optional predicate.
    IndexSeek {
        table: String,
        index: String,
        columns: Vec<String>,
        key_range: KeyRange,
        predicate: Option<Box<Predicate>>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Applies `predicate` to the rows of `input`.
    Filter {
        input: Box<PhysicalPlan>,
        predicate: Predicate,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Restricts `input` to `columns`. Has no row estimate of its own.
    Project {
        input: Box<PhysicalPlan>,
        columns: Vec<String>,
        estimated_cost: f64,
    },
    /// Orders `input` by the `(column, direction)` pairs. Has no row estimate of its own.
    Sort {
        input: Box<PhysicalPlan>,
        order_by: Vec<(String, SortDirection)>,
        estimated_cost: f64,
    },
    /// Row-count cap on `input`, with an `offset`.
    Limit {
        input: Box<PhysicalPlan>,
        limit: u64,
        offset: u64,
        estimated_cost: f64,
    },
    /// Join of `outer` and `inner` on an arbitrary `condition`.
    NestedLoopJoin {
        outer: Box<PhysicalPlan>,
        inner: Box<PhysicalPlan>,
        condition: Predicate,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Join of a `build` and a `probe` input on the listed key columns.
    HashJoin {
        build: Box<PhysicalPlan>,
        probe: Box<PhysicalPlan>,
        build_keys: Vec<String>,
        probe_keys: Vec<String>,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Join of `left` and `right` on the listed key columns.
    MergeJoin {
        left: Box<PhysicalPlan>,
        right: Box<PhysicalPlan>,
        left_keys: Vec<String>,
        right_keys: Vec<String>,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Grouping of `input` by `group_by`, computing `aggregates`.
    Aggregate {
        input: Box<PhysicalPlan>,
        group_by: Vec<String>,
        aggregates: Vec<AggregateExpr>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
}
368
/// Range of index keys, expressed as raw byte strings.
///
/// A missing bound (`None`) leaves that side of the range open.
#[derive(Debug, Clone)]
pub struct KeyRange {
    /// Lower bound, or `None` for no lower bound.
    pub start: Option<Vec<u8>>,
    /// Upper bound, or `None` for no upper bound.
    pub end: Option<Vec<u8>>,
    /// Whether `start` itself belongs to the range.
    pub start_inclusive: bool,
    /// Whether `end` itself belongs to the range.
    pub end_inclusive: bool,
}

impl KeyRange {
    /// Range without bounds on either side; both inclusive flags are set.
    pub fn all() -> Self {
        Self::range(None, None, true)
    }

    /// Range containing exactly `key`: both bounds equal `key` and are inclusive.
    pub fn point(key: Vec<u8>) -> Self {
        let lower = key.clone();
        Self::range(Some(lower), Some(key), true)
    }

    /// Range from `start` to `end`, with `inclusive` applied to both ends.
    pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
        KeyRange {
            start_inclusive: inclusive,
            end_inclusive: inclusive,
            start,
            end,
        }
    }
}
406
/// Ordering direction attached to each sort column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Ascending,
    Descending,
}
413
/// Join kind recorded on join plan nodes.
///
/// The code visible in this file only stores and prints this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    Cross,
}
423
/// One aggregate computed by an `Aggregate` plan node.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// Aggregate function to apply.
    pub function: AggregateFunction,
    /// Input column; `None` is rendered as `*` in plan explanations.
    pub column: Option<String>,
    /// Alias for the result. Not read by the code in this file.
    pub alias: String,
}
431
/// Aggregate functions available to [`AggregateExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    CountDistinct,
}
442
/// Single-table planner that picks between a table scan and an index seek by
/// estimated cost, then layers sort and limit nodes on top.
pub struct CostBasedOptimizer {
    /// Cost weights used by every estimate.
    config: CostModelConfig,
    /// Table statistics keyed by table name, shared behind a read/write lock.
    stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
    /// Optional token budget; when set, plans are capped to the rows that fit in it.
    token_budget: Option<u64>,
    /// Tokens assumed per output row when turning the budget into a row cap.
    tokens_per_row: f64,
}
458
459impl CostBasedOptimizer {
460 pub fn new(config: CostModelConfig) -> Self {
461 Self {
462 config,
463 stats_cache: Arc::new(RwLock::new(HashMap::new())),
464 token_budget: None,
465 tokens_per_row: 25.0, }
467 }
468
469 pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
471 self.token_budget = Some(budget);
472 self.tokens_per_row = tokens_per_row;
473 self
474 }
475
476 pub fn update_stats(&self, stats: TableStats) {
478 self.stats_cache.write().insert(stats.name.clone(), stats);
479 }
480
481 pub fn get_stats(&self, table: &str) -> Option<TableStats> {
483 self.stats_cache.read().get(table).cloned()
484 }
485
486 pub fn optimize(
488 &self,
489 table: &str,
490 columns: Vec<String>,
491 predicate: Option<Predicate>,
492 order_by: Vec<(String, SortDirection)>,
493 limit: Option<u64>,
494 ) -> PhysicalPlan {
495 let stats = self.get_stats(table);
496
497 let effective_limit = self.calculate_token_limit(limit);
499
500 let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
502
503 plan = self.apply_projection_pushdown(plan, columns.clone());
505
506 if !order_by.is_empty() {
508 plan = self.add_sort(plan, order_by, &stats);
509 }
510
511 if let Some(lim) = effective_limit {
513 plan = PhysicalPlan::Limit {
514 estimated_cost: 0.0,
515 input: Box::new(plan),
516 limit: lim,
517 offset: 0,
518 };
519 }
520
521 plan
522 }
523
524 fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
526 match (self.token_budget, user_limit) {
527 (Some(budget), Some(limit)) => {
528 let header_tokens = 50u64;
529 let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
530 Some(limit.min(max_rows))
531 }
532 (Some(budget), None) => {
533 let header_tokens = 50u64;
534 let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
535 Some(max_rows)
536 }
537 (None, limit) => limit,
538 }
539 }
540
541 fn choose_access_path(
543 &self,
544 table: &str,
545 columns: &[String],
546 predicate: Option<&Predicate>,
547 stats: &Option<TableStats>,
548 ) -> PhysicalPlan {
549 let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
550 let size_bytes = stats
551 .as_ref()
552 .map(|s| s.size_bytes)
553 .unwrap_or(row_count * 100);
554
555 let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
557
558 let mut best_index_cost = f64::MAX;
560 let mut best_index: Option<&IndexStats> = None;
561
562 if let Some(table_stats) = stats.as_ref()
563 && let Some(pred) = predicate
564 {
565 let pred_columns = pred.referenced_columns();
566
567 for index in &table_stats.indices {
568 if self.index_covers_predicate(index, &pred_columns) {
569 let selectivity = self.estimate_selectivity(pred, table_stats);
570 let index_cost = self.estimate_index_cost(index, row_count, selectivity);
571
572 if index_cost < best_index_cost {
573 best_index_cost = index_cost;
574 best_index = Some(index);
575 }
576 }
577 }
578 }
579
580 if best_index_cost < scan_cost {
582 let index = best_index.unwrap();
583 let selectivity = predicate
584 .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
585 .unwrap_or(1.0);
586
587 PhysicalPlan::IndexSeek {
588 table: table.to_string(),
589 index: index.name.clone(),
590 columns: columns.to_vec(),
591 key_range: KeyRange::all(), predicate: predicate.map(|p| Box::new(p.clone())),
593 estimated_rows: (row_count as f64 * selectivity) as u64,
594 estimated_cost: best_index_cost,
595 }
596 } else {
597 PhysicalPlan::TableScan {
598 table: table.to_string(),
599 columns: columns.to_vec(),
600 predicate: predicate.map(|p| Box::new(p.clone())),
601 estimated_rows: row_count,
602 estimated_cost: scan_cost,
603 }
604 }
605 }
606
607 fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
609 if let Some(first_col) = index.columns.first() {
611 pred_columns.contains(first_col)
612 } else {
613 false
614 }
615 }
616
617 fn estimate_scan_cost(
619 &self,
620 row_count: u64,
621 size_bytes: u64,
622 predicate: Option<&Predicate>,
623 ) -> f64 {
624 let blocks = (size_bytes as f64 / self.config.block_size as f64).ceil() as u64;
625
626 let io_cost = blocks as f64 * self.config.c_seq;
628
629 let selectivity = predicate.map(|_| 0.1).unwrap_or(1.0);
631 let cpu_cost = row_count as f64 * self.config.c_filter * selectivity;
632
633 io_cost + cpu_cost
634 }
635
636 fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
640 let tree_cost = index.height as f64 * self.config.c_random;
642
643 let matching_rows = (total_rows as f64 * selectivity) as u64;
645 let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
646 let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
647
648 let fetch_cost = if index.is_primary {
650 0.0 } else {
652 matching_rows.min(1000) as f64 * self.config.c_random * 0.1 };
654
655 tree_cost + leaf_cost + fetch_cost
656 }
657
658 #[allow(clippy::only_used_in_recursion)]
660 fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
661 match predicate {
662 Predicate::Eq { column, value } => {
663 if let Some(col_stats) = stats.column_stats.get(column) {
664 for (mcv_val, freq) in &col_stats.mcv {
666 if mcv_val == value {
667 return *freq;
668 }
669 }
670 1.0 / col_stats.distinct_count.max(1) as f64
672 } else {
673 0.1 }
675 }
676 Predicate::Ne { .. } => 0.9, Predicate::Lt { column, value }
678 | Predicate::Le { column, value }
679 | Predicate::Gt { column, value }
680 | Predicate::Ge { column, value } => {
681 if let Some(col_stats) = stats.column_stats.get(column) {
682 if let Some(ref hist) = col_stats.histogram {
683 let val: f64 = value.parse().unwrap_or(0.0);
684 match predicate {
685 Predicate::Lt { .. } | Predicate::Le { .. } => {
686 hist.estimate_range_selectivity(None, Some(val))
687 }
688 _ => hist.estimate_range_selectivity(Some(val), None),
689 }
690 } else {
691 0.25 }
693 } else {
694 0.25
695 }
696 }
697 Predicate::Between { column, min, max } => {
698 if let Some(col_stats) = stats.column_stats.get(column) {
699 if let Some(ref hist) = col_stats.histogram {
700 let min_val: f64 = min.parse().unwrap_or(0.0);
701 let max_val: f64 = max.parse().unwrap_or(f64::MAX);
702 hist.estimate_range_selectivity(Some(min_val), Some(max_val))
703 } else {
704 0.2
705 }
706 } else {
707 0.2
708 }
709 }
710 Predicate::In { column, values } => {
711 if let Some(col_stats) = stats.column_stats.get(column) {
712 (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
713 } else {
714 (values.len() as f64 * 0.1).min(0.5)
715 }
716 }
717 Predicate::Like { .. } => 0.15, Predicate::IsNull { column } => {
719 if let Some(col_stats) = stats.column_stats.get(column) {
720 col_stats.null_count as f64 / stats.row_count.max(1) as f64
721 } else {
722 0.01
723 }
724 }
725 Predicate::IsNotNull { column } => {
726 if let Some(col_stats) = stats.column_stats.get(column) {
727 1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
728 } else {
729 0.99
730 }
731 }
732 Predicate::And(left, right) => {
733 self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
735 }
736 Predicate::Or(left, right) => {
737 let s1 = self.estimate_selectivity(left, stats);
738 let s2 = self.estimate_selectivity(right, stats);
739 (s1 + s2 - s1 * s2).min(1.0)
741 }
742 Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
743 }
744 }
745
746 fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
748 match plan {
750 PhysicalPlan::TableScan {
751 table,
752 predicate,
753 estimated_rows,
754 estimated_cost,
755 ..
756 } => {
757 PhysicalPlan::TableScan {
758 table,
759 columns, predicate,
761 estimated_rows,
762 estimated_cost: estimated_cost * 0.2, }
764 }
765 PhysicalPlan::IndexSeek {
766 table,
767 index,
768 key_range,
769 predicate,
770 estimated_rows,
771 estimated_cost,
772 ..
773 } => {
774 PhysicalPlan::IndexSeek {
775 table,
776 index,
777 columns, key_range,
779 predicate,
780 estimated_rows,
781 estimated_cost,
782 }
783 }
784 other => PhysicalPlan::Project {
785 input: Box::new(other),
786 columns,
787 estimated_cost: 0.0,
788 },
789 }
790 }
791
792 fn add_sort(
794 &self,
795 plan: PhysicalPlan,
796 order_by: Vec<(String, SortDirection)>,
797 _stats: &Option<TableStats>,
798 ) -> PhysicalPlan {
799 let estimated_rows = self.get_plan_rows(&plan);
800 let sort_cost = if estimated_rows > 0 {
801 estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
802 } else {
803 0.0
804 };
805
806 PhysicalPlan::Sort {
807 input: Box::new(plan),
808 order_by,
809 estimated_cost: sort_cost,
810 }
811 }
812
813 #[allow(clippy::only_used_in_recursion)]
815 fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
816 match plan {
817 PhysicalPlan::TableScan { estimated_rows, .. }
818 | PhysicalPlan::IndexSeek { estimated_rows, .. }
819 | PhysicalPlan::Filter { estimated_rows, .. }
820 | PhysicalPlan::Aggregate { estimated_rows, .. }
821 | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
822 | PhysicalPlan::HashJoin { estimated_rows, .. }
823 | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
824 PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
825 self.get_plan_rows(input)
826 }
827 PhysicalPlan::Limit { limit, .. } => *limit,
828 }
829 }
830
831 #[allow(clippy::only_used_in_recursion)]
833 pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
834 match plan {
835 PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
836 PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
837 PhysicalPlan::Filter {
838 estimated_cost,
839 input,
840 ..
841 } => *estimated_cost + self.get_plan_cost(input),
842 PhysicalPlan::Project {
843 estimated_cost,
844 input,
845 ..
846 } => *estimated_cost + self.get_plan_cost(input),
847 PhysicalPlan::Sort {
848 estimated_cost,
849 input,
850 ..
851 } => *estimated_cost + self.get_plan_cost(input),
852 PhysicalPlan::Limit {
853 estimated_cost,
854 input,
855 ..
856 } => *estimated_cost + self.get_plan_cost(input),
857 PhysicalPlan::NestedLoopJoin {
858 estimated_cost,
859 outer,
860 inner,
861 ..
862 } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
863 PhysicalPlan::HashJoin {
864 estimated_cost,
865 build,
866 probe,
867 ..
868 } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
869 PhysicalPlan::MergeJoin {
870 estimated_cost,
871 left,
872 right,
873 ..
874 } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
875 PhysicalPlan::Aggregate {
876 estimated_cost,
877 input,
878 ..
879 } => *estimated_cost + self.get_plan_cost(input),
880 }
881 }
882
883 pub fn explain(&self, plan: &PhysicalPlan) -> String {
885 self.explain_impl(plan, 0)
886 }
887
/// Renders `plan` and, recursively, its inputs.
///
/// Each node becomes one line prefixed by `indent` copies of the indentation
/// string; inputs follow on their own lines at `indent + 1`. The `cost` shown
/// for a node is its total cost, inputs included.
fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
    let prefix = " ".repeat(indent);
    // Cumulative cost of this node and everything beneath it.
    let cost = self.get_plan_cost(plan);

    match plan {
        // Leaf nodes: a single line, nothing to recurse into.
        PhysicalPlan::TableScan {
            table,
            columns,
            estimated_rows,
            ..
        } => {
            format!(
                "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
                prefix, table, columns, estimated_rows, cost
            )
        }
        PhysicalPlan::IndexSeek {
            table,
            index,
            columns,
            estimated_rows,
            ..
        } => {
            format!(
                "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
                prefix, table, index, columns, estimated_rows, cost
            )
        }
        // Single-input nodes: own line, then the input one level deeper.
        PhysicalPlan::Filter {
            input,
            estimated_rows,
            ..
        } => {
            format!(
                "{}Filter [rows={}, cost={:.2}ms]\n{}",
                prefix,
                estimated_rows,
                cost,
                self.explain_impl(input, indent + 1)
            )
        }
        PhysicalPlan::Project { input, columns, .. } => {
            format!(
                "{}Project [columns={:?}, cost={:.2}ms]\n{}",
                prefix,
                columns,
                cost,
                self.explain_impl(input, indent + 1)
            )
        }
        PhysicalPlan::Sort {
            input, order_by, ..
        } => {
            // Render each sort key as "<column> <Direction>".
            let order: Vec<_> = order_by
                .iter()
                .map(|(c, d)| format!("{} {:?}", c, d))
                .collect();
            format!(
                "{}Sort [order={:?}, cost={:.2}ms]\n{}",
                prefix,
                order,
                cost,
                self.explain_impl(input, indent + 1)
            )
        }
        PhysicalPlan::Limit {
            input,
            limit,
            offset,
            ..
        } => {
            format!(
                "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
                prefix,
                limit,
                offset,
                cost,
                self.explain_impl(input, indent + 1)
            )
        }
        // Two-input nodes: own line, then both inputs in declaration order.
        PhysicalPlan::HashJoin {
            build,
            probe,
            join_type,
            estimated_rows,
            ..
        } => {
            format!(
                "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
                prefix,
                join_type,
                estimated_rows,
                cost,
                self.explain_impl(build, indent + 1),
                self.explain_impl(probe, indent + 1)
            )
        }
        PhysicalPlan::MergeJoin {
            left,
            right,
            join_type,
            estimated_rows,
            ..
        } => {
            format!(
                "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
                prefix,
                join_type,
                estimated_rows,
                cost,
                self.explain_impl(left, indent + 1),
                self.explain_impl(right, indent + 1)
            )
        }
        PhysicalPlan::NestedLoopJoin {
            outer,
            inner,
            join_type,
            estimated_rows,
            ..
        } => {
            format!(
                "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
                prefix,
                join_type,
                estimated_rows,
                cost,
                self.explain_impl(outer, indent + 1),
                self.explain_impl(inner, indent + 1)
            )
        }
        PhysicalPlan::Aggregate {
            input,
            group_by,
            aggregates,
            estimated_rows,
            ..
        } => {
            // Render each aggregate as "<Function>(<column>)", with "*" when
            // no column is given.
            let aggs: Vec<_> = aggregates
                .iter()
                .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
                .collect();
            format!(
                "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
                prefix,
                group_by,
                aggs,
                estimated_rows,
                cost,
                self.explain_impl(input, indent + 1)
            )
        }
    }
}
1042}
1043
/// Searches for a cheap order in which to join a set of tables, using
/// per-table row counts and the shared cost weights.
pub struct JoinOrderOptimizer {
    /// Table statistics keyed by table name; only `row_count` is read here.
    stats: HashMap<String, TableStats>,
    /// Cost weights; the join estimate uses `c_filter`.
    config: CostModelConfig,
}
1055
1056impl JoinOrderOptimizer {
1057 pub fn new(config: CostModelConfig) -> Self {
1058 Self {
1059 stats: HashMap::new(),
1060 config,
1061 }
1062 }
1063
1064 pub fn add_stats(&mut self, stats: TableStats) {
1066 self.stats.insert(stats.name.clone(), stats);
1067 }
1068
1069 pub fn find_optimal_order(
1074 &self,
1075 tables: &[String],
1076 join_conditions: &[(String, String, String, String)], ) -> Vec<(String, String)> {
1078 let n = tables.len();
1079 if n <= 1 {
1080 return vec![];
1081 }
1082
1083 let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1085
1086 for (i, _table) in tables.iter().enumerate() {
1088 let mask = 1u32 << i;
1089 dp.insert(mask, (0.0, vec![]));
1090 }
1091
1092 for size in 2..=n {
1094 for mask in 0..(1u32 << n) {
1095 if mask.count_ones() != size as u32 {
1096 continue;
1097 }
1098
1099 let mut best_cost = f64::MAX;
1100 let mut best_order = vec![];
1101
1102 for sub in 1..mask {
1104 if sub & mask != sub || sub == 0 {
1105 continue;
1106 }
1107 let other = mask ^ sub;
1108 if other == 0 {
1109 continue;
1110 }
1111
1112 if !self.has_join_condition(tables, sub, other, join_conditions) {
1114 continue;
1115 }
1116
1117 if let (Some((cost1, order1)), Some((cost2, order2))) =
1118 (dp.get(&sub), dp.get(&other))
1119 {
1120 let join_cost = self.estimate_join_cost(tables, sub, other);
1121 let total_cost = cost1 + cost2 + join_cost;
1122
1123 if total_cost < best_cost {
1124 best_cost = total_cost;
1125 best_order = order1.clone();
1126 best_order.extend(order2.clone());
1127
1128 let (t1, t2) =
1130 self.get_join_tables(tables, sub, other, join_conditions);
1131 if let Some((t1, t2)) = Some((t1, t2)) {
1132 best_order.push((t1, t2));
1133 }
1134 }
1135 }
1136 }
1137
1138 if best_cost < f64::MAX {
1139 dp.insert(mask, (best_cost, best_order));
1140 }
1141 }
1142 }
1143
1144 let full_mask = (1u32 << n) - 1;
1145 dp.get(&full_mask)
1146 .map(|(_, order)| order.clone())
1147 .unwrap_or_default()
1148 }
1149
1150 fn has_join_condition(
1151 &self,
1152 tables: &[String],
1153 mask1: u32,
1154 mask2: u32,
1155 conditions: &[(String, String, String, String)],
1156 ) -> bool {
1157 for (t1, _, t2, _) in conditions {
1158 let idx1 = tables.iter().position(|t| t == t1);
1159 let idx2 = tables.iter().position(|t| t == t2);
1160
1161 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1162 let in_mask1 = (mask1 >> i1) & 1 == 1;
1163 let in_mask2 = (mask2 >> i2) & 1 == 1;
1164
1165 if in_mask1 && in_mask2 {
1166 return true;
1167 }
1168 }
1169 }
1170 false
1171 }
1172
1173 fn get_join_tables(
1174 &self,
1175 tables: &[String],
1176 mask1: u32,
1177 mask2: u32,
1178 conditions: &[(String, String, String, String)],
1179 ) -> (String, String) {
1180 for (t1, _, t2, _) in conditions {
1181 let idx1 = tables.iter().position(|t| t == t1);
1182 let idx2 = tables.iter().position(|t| t == t2);
1183
1184 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1185 let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1186 let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1187
1188 if t1_in_mask1 && t2_in_mask2 {
1189 return (t1.clone(), t2.clone());
1190 }
1191 }
1192 }
1193 (String::new(), String::new())
1194 }
1195
1196 fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1197 let rows1 = self.estimate_rows_for_mask(tables, mask1);
1198 let rows2 = self.estimate_rows_for_mask(tables, mask2);
1199
1200 let build_cost = rows1 as f64 * self.config.c_filter;
1203 let probe_cost = rows2 as f64 * self.config.c_filter;
1204
1205 build_cost + probe_cost
1206 }
1207
1208 fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1209 let mut total = 1u64;
1210
1211 for (i, table) in tables.iter().enumerate() {
1212 if (mask >> i) & 1 == 1 {
1213 let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1214 total = total.saturating_mul(rows);
1215 }
1216 }
1217
1218 let num_tables = mask.count_ones();
1220 if num_tables > 1 {
1221 total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1222 }
1223
1224 total.max(1)
1225 }
1226}
1227
/// Unit tests for selectivity estimation, access-path choice, token-budget
/// limiting and plan explanation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds statistics for a 100000-row `users` table with a primary-key
    /// index on `id` and a secondary index on `score`.
    fn create_test_stats() -> TableStats {
        let mut column_stats = HashMap::new();
        // `id`: every value distinct, no NULLs, no MCVs, no histogram.
        column_stats.insert(
            "id".to_string(),
            ColumnStats {
                name: "id".to_string(),
                distinct_count: 100000,
                null_count: 0,
                min_value: Some("1".to_string()),
                max_value: Some("100000".to_string()),
                avg_length: 8.0,
                mcv: vec![],
                histogram: None,
            },
        );
        // `score`: 100 distinct values, one MCV, and four equally filled buckets.
        column_stats.insert(
            "score".to_string(),
            ColumnStats {
                name: "score".to_string(),
                distinct_count: 100,
                null_count: 1000,
                min_value: Some("0".to_string()),
                max_value: Some("100".to_string()),
                avg_length: 8.0,
                mcv: vec![("50".to_string(), 0.05)],
                histogram: Some(Histogram {
                    boundaries: vec![25.0, 50.0, 75.0, 100.0],
                    counts: vec![25000, 25000, 25000, 25000],
                    total_rows: 100000,
                }),
            },
        );

        TableStats {
            name: "users".to_string(),
            row_count: 100000,
            size_bytes: 10_000_000,
            column_stats,
            indices: vec![
                IndexStats {
                    name: "pk_users".to_string(),
                    columns: vec!["id".to_string()],
                    is_primary: true,
                    is_unique: true,
                    index_type: IndexType::BTree,
                    leaf_pages: 1000,
                    height: 3,
                    avg_leaf_density: 100.0,
                },
                IndexStats {
                    name: "idx_score".to_string(),
                    columns: vec!["score".to_string()],
                    is_primary: false,
                    is_unique: false,
                    index_type: IndexType::BTree,
                    leaf_pages: 500,
                    height: 2,
                    avg_leaf_density: 200.0,
                },
            ],
            last_updated: 0,
        }
    }

    #[test]
    fn test_selectivity_estimation() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats.clone());

        // Equality on `id`: no MCV entry, so the estimate is 1 / 100000.
        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        assert!(sel < 0.001);

        // `score > 75`: the histogram counts every bucket whose upper edge is
        // at least 75, i.e. two of the four buckets.
        let pred = Predicate::Gt {
            column: "score".to_string(),
            value: "75".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        assert!(sel > 0.4 && sel < 0.6);
    }

    #[test]
    fn test_access_path_selection() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats);

        // A point lookup on the primary key should beat a full scan.
        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(pred),
            vec![],
            None,
        );

        match plan {
            PhysicalPlan::IndexSeek { index, .. } => {
                assert_eq!(index, "pk_users");
            }
            _ => panic!("Expected IndexSeek for equality on primary key"),
        }
    }

    #[test]
    fn test_token_budget_limit() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);

        // No user limit: the budget alone must inject a Limit node.
        // (2048 - 50 header tokens) / 25 tokens per row gives 79 rows.
        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);

        match plan {
            PhysicalPlan::Limit { limit, .. } => {
                assert!(limit <= 80);
            }
            _ => panic!("Expected Limit to be injected"),
        }
    }

    #[test]
    fn test_explain_output() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats);

        // Filter + ORDER BY + LIMIT: the rendered plan must mention both the
        // Limit and the Sort node.
        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(Predicate::Gt {
                column: "score".to_string(),
                value: "80".to_string(),
            }),
            vec![("score".to_string(), SortDirection::Descending)],
            Some(10),
        );

        let explain = optimizer.explain(&plan);
        assert!(explain.contains("Limit"));
        assert!(explain.contains("Sort"));
    }
}