1use parking_lot::RwLock;
47use std::collections::{HashMap, HashSet};
48use std::sync::Arc;
49use std::time::{SystemTime, UNIX_EPOCH};
50
/// Tunable weights for the cost model shared by [`CostBasedOptimizer`] and
/// [`JoinOrderOptimizer`].
///
/// The weights are unitless; `explain` output merely renders plan costs with
/// an `ms` suffix.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Cost per block read sequentially (table-scan blocks, index leaf pages).
    pub c_seq: f64,
    /// Cost per random access (index tree levels, heap fetches).
    pub c_random: f64,
    /// Per-row CPU cost of evaluating a predicate or touching a join input.
    pub c_filter: f64,
    /// Per-comparison CPU cost used when costing sorts.
    pub c_compare: f64,
    /// Bytes per storage block; table size divided by this gives the block count.
    pub block_size: usize,
    /// Expected B-tree fan-out. Not read by the cost formulas in this module.
    pub btree_fanout: usize,
    /// Memory bandwidth figure. Not read by the cost formulas in this module.
    pub memory_bandwidth: f64,
}
73
74impl Default for CostModelConfig {
75 fn default() -> Self {
76 Self {
77 c_seq: 0.1, c_random: 5.0, c_filter: 0.001, c_compare: 0.0001, block_size: 4096, btree_fanout: 100, memory_bandwidth: 10000.0, }
85 }
86}
87
88#[derive(Debug, Clone)]
94pub struct TableStats {
95 pub name: String,
97 pub row_count: u64,
99 pub size_bytes: u64,
101 pub column_stats: HashMap<String, ColumnStats>,
103 pub indices: Vec<IndexStats>,
105 pub last_updated: u64,
107}
108
109#[derive(Debug, Clone)]
111pub struct ColumnStats {
112 pub name: String,
114 pub distinct_count: u64,
116 pub null_count: u64,
118 pub min_value: Option<String>,
120 pub max_value: Option<String>,
122 pub avg_length: f64,
124 pub mcv: Vec<(String, f64)>,
126 pub histogram: Option<Histogram>,
128}
129
/// Bucketed distribution of a numeric column, used to estimate range
/// selectivity.
///
/// Bucket `i` holds the values between `boundaries[i - 1]` and
/// `boundaries[i]`; the first bucket has no lower bound.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Upper boundary of each bucket, expected in ascending order.
    pub boundaries: Vec<f64>,
    /// Number of rows in each bucket.
    pub counts: Vec<u64>,
    /// Total rows the histogram describes; the denominator of selectivity.
    pub total_rows: u64,
}
140
141impl Histogram {
142 pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
144 if self.total_rows == 0 {
145 return 0.5; }
147
148 let mut selected_rows = 0u64;
149
150 for (i, &count) in self.counts.iter().enumerate() {
151 let bucket_min = if i == 0 {
152 f64::NEG_INFINITY
153 } else {
154 self.boundaries[i - 1]
155 };
156 let bucket_max = if i == self.boundaries.len() {
157 f64::INFINITY
158 } else {
159 self.boundaries[i]
160 };
161
162 let overlaps = match (min, max) {
163 (Some(min_val), Some(max_val)) => bucket_max >= min_val && bucket_min <= max_val,
164 (Some(min_val), None) => bucket_max >= min_val,
165 (None, Some(max_val)) => bucket_min <= max_val,
166 (None, None) => true,
167 };
168
169 if overlaps {
170 selected_rows += count;
171 }
172 }
173
174 selected_rows as f64 / self.total_rows as f64
175 }
176}
177
178#[derive(Debug, Clone)]
180pub struct IndexStats {
181 pub name: String,
183 pub columns: Vec<String>,
185 pub is_primary: bool,
187 pub is_unique: bool,
189 pub index_type: IndexType,
191 pub leaf_pages: u64,
193 pub height: u32,
195 pub avg_leaf_density: f64,
197}
198
/// Kind of index structure.
///
/// Only recorded for reporting; the cost formulas in this module do not
/// branch on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// B-tree index.
    BTree,
    /// Hash index.
    Hash,
    /// Log-structured merge tree.
    LSM,
    /// Learned index.
    Learned,
    /// Vector index.
    Vector,
    /// Bloom filter.
    Bloom,
}
209
/// Filter expression over a single table's columns.
///
/// Literal values are kept as strings; range estimates parse them as `f64`.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// `column = value`
    Eq { column: String, value: String },
    /// `column <> value`
    Ne { column: String, value: String },
    /// `column < value`
    Lt { column: String, value: String },
    /// `column <= value`
    Le { column: String, value: String },
    /// `column > value`
    Gt { column: String, value: String },
    /// `column >= value`
    Ge { column: String, value: String },
    /// `column BETWEEN min AND max`
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// `column IN (values...)`
    In { column: String, values: Vec<String> },
    /// `column LIKE pattern`
    Like { column: String, pattern: String },
    /// `column IS NULL`
    IsNull { column: String },
    /// `column IS NOT NULL`
    IsNotNull { column: String },
    /// Conjunction of two predicates.
    And(Box<Predicate>, Box<Predicate>),
    /// Disjunction of two predicates.
    Or(Box<Predicate>, Box<Predicate>),
    /// Negation of a predicate.
    Not(Box<Predicate>),
}
250
251impl Predicate {
252 pub fn referenced_columns(&self) -> HashSet<String> {
254 let mut cols = HashSet::new();
255 self.collect_columns(&mut cols);
256 cols
257 }
258
259 fn collect_columns(&self, cols: &mut HashSet<String>) {
260 match self {
261 Self::Eq { column, .. }
262 | Self::Ne { column, .. }
263 | Self::Lt { column, .. }
264 | Self::Le { column, .. }
265 | Self::Gt { column, .. }
266 | Self::Ge { column, .. }
267 | Self::Between { column, .. }
268 | Self::In { column, .. }
269 | Self::Like { column, .. }
270 | Self::IsNull { column }
271 | Self::IsNotNull { column } => {
272 cols.insert(column.clone());
273 }
274 Self::And(left, right) | Self::Or(left, right) => {
275 left.collect_columns(cols);
276 right.collect_columns(cols);
277 }
278 Self::Not(inner) => inner.collect_columns(cols),
279 }
280 }
281}
282
283#[derive(Debug, Clone)]
289pub enum PhysicalPlan {
290 TableScan {
292 table: String,
293 columns: Vec<String>,
294 predicate: Option<Box<Predicate>>,
295 estimated_rows: u64,
296 estimated_cost: f64,
297 },
298 IndexSeek {
300 table: String,
301 index: String,
302 columns: Vec<String>,
303 key_range: KeyRange,
304 predicate: Option<Box<Predicate>>,
305 estimated_rows: u64,
306 estimated_cost: f64,
307 },
308 Filter {
310 input: Box<PhysicalPlan>,
311 predicate: Predicate,
312 estimated_rows: u64,
313 estimated_cost: f64,
314 },
315 Project {
317 input: Box<PhysicalPlan>,
318 columns: Vec<String>,
319 estimated_cost: f64,
320 },
321 Sort {
323 input: Box<PhysicalPlan>,
324 order_by: Vec<(String, SortDirection)>,
325 estimated_cost: f64,
326 },
327 Limit {
329 input: Box<PhysicalPlan>,
330 limit: u64,
331 offset: u64,
332 estimated_cost: f64,
333 },
334 NestedLoopJoin {
336 outer: Box<PhysicalPlan>,
337 inner: Box<PhysicalPlan>,
338 condition: Predicate,
339 join_type: JoinType,
340 estimated_rows: u64,
341 estimated_cost: f64,
342 },
343 HashJoin {
345 build: Box<PhysicalPlan>,
346 probe: Box<PhysicalPlan>,
347 build_keys: Vec<String>,
348 probe_keys: Vec<String>,
349 join_type: JoinType,
350 estimated_rows: u64,
351 estimated_cost: f64,
352 },
353 MergeJoin {
355 left: Box<PhysicalPlan>,
356 right: Box<PhysicalPlan>,
357 left_keys: Vec<String>,
358 right_keys: Vec<String>,
359 join_type: JoinType,
360 estimated_rows: u64,
361 estimated_cost: f64,
362 },
363 Aggregate {
365 input: Box<PhysicalPlan>,
366 group_by: Vec<String>,
367 aggregates: Vec<AggregateExpr>,
368 estimated_rows: u64,
369 estimated_cost: f64,
370 },
371}
372
/// Byte-key interval scanned by an index seek.
///
/// A `None` bound is open on that side; the `*_inclusive` flags say whether
/// each present bound is itself part of the range.
#[derive(Debug, Clone)]
pub struct KeyRange {
    /// Lower bound, or `None` for "from the first key".
    pub start: Option<Vec<u8>>,
    /// Upper bound, or `None` for "to the last key".
    pub end: Option<Vec<u8>>,
    /// Whether `start` itself is included.
    pub start_inclusive: bool,
    /// Whether `end` itself is included.
    pub end_inclusive: bool,
}
381
382impl KeyRange {
383 pub fn all() -> Self {
384 Self {
385 start: None,
386 end: None,
387 start_inclusive: true,
388 end_inclusive: true,
389 }
390 }
391
392 pub fn point(key: Vec<u8>) -> Self {
393 Self {
394 start: Some(key.clone()),
395 end: Some(key),
396 start_inclusive: true,
397 end_inclusive: true,
398 }
399 }
400
401 pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
402 Self {
403 start,
404 end,
405 start_inclusive: inclusive,
406 end_inclusive: inclusive,
407 }
408 }
409}
410
/// Direction of one sort key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Smallest values first.
    Ascending,
    /// Largest values first.
    Descending,
}
417
/// Join semantics of a join node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    Left,
    /// Right outer join.
    Right,
    /// Full outer join.
    Full,
    /// Cross join.
    Cross,
}
427
428#[derive(Debug, Clone)]
430pub struct AggregateExpr {
431 pub function: AggregateFunction,
432 pub column: Option<String>,
433 pub alias: String,
434}
435
/// Supported aggregate functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
    /// `COUNT(DISTINCT ...)`
    CountDistinct,
}
446
447pub struct CostBasedOptimizer {
453 config: CostModelConfig,
455 stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
457 token_budget: Option<u64>,
459 tokens_per_row: f64,
461 plan_cache: Arc<RwLock<HashMap<u64, (PhysicalPlan, u64)>>>,
463 plan_cache_ttl_us: u64,
465}
466
467impl CostBasedOptimizer {
468 pub fn new(config: CostModelConfig) -> Self {
469 Self {
470 config,
471 stats_cache: Arc::new(RwLock::new(HashMap::new())),
472 token_budget: None,
473 tokens_per_row: 25.0, plan_cache: Arc::new(RwLock::new(HashMap::new())),
475 plan_cache_ttl_us: 5_000_000, }
477 }
478
479 pub fn with_plan_cache_ttl_ms(mut self, ttl_ms: u64) -> Self {
481 self.plan_cache_ttl_us = ttl_ms * 1000;
482 self
483 }
484
485 pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
487 self.token_budget = Some(budget);
488 self.tokens_per_row = tokens_per_row;
489 self
490 }
491
492 pub fn update_stats(&self, stats: TableStats) {
494 self.stats_cache.write().insert(stats.name.clone(), stats);
495 }
496
497 pub fn get_stats(&self, table: &str) -> Option<TableStats> {
499 self.stats_cache.read().get(table).cloned()
500 }
501
502 pub fn optimize(
504 &self,
505 table: &str,
506 columns: Vec<String>,
507 predicate: Option<Predicate>,
508 order_by: Vec<(String, SortDirection)>,
509 limit: Option<u64>,
510 ) -> PhysicalPlan {
511 let stats = self.get_stats(table);
512
513 let effective_limit = self.calculate_token_limit(limit);
515
516 let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
518
519 plan = self.apply_projection_pushdown(plan, columns.clone());
521
522 if !order_by.is_empty() {
524 plan = self.add_sort(plan, order_by, &stats);
525 }
526
527 if let Some(lim) = effective_limit {
529 plan = PhysicalPlan::Limit {
530 estimated_cost: 0.0,
531 input: Box::new(plan),
532 limit: lim,
533 offset: 0,
534 };
535 }
536
537 plan
538 }
539
540 fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
542 match (self.token_budget, user_limit) {
543 (Some(budget), Some(limit)) => {
544 let header_tokens = 50u64;
545 let usable = budget.saturating_sub(header_tokens);
546 let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
547 Some(limit.min(max_rows))
548 }
549 (Some(budget), None) => {
550 let header_tokens = 50u64;
551 let usable = budget.saturating_sub(header_tokens);
552 let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
553 Some(max_rows)
554 }
555 (None, limit) => limit,
556 }
557 }
558
559 fn choose_access_path(
561 &self,
562 table: &str,
563 columns: &[String],
564 predicate: Option<&Predicate>,
565 stats: &Option<TableStats>,
566 ) -> PhysicalPlan {
567 let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
568 let size_bytes = stats
569 .as_ref()
570 .map(|s| s.size_bytes)
571 .unwrap_or(row_count * 100);
572
573 let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
575
576 let mut best_index_cost = f64::MAX;
578 let mut best_index: Option<&IndexStats> = None;
579
580 if let Some(table_stats) = stats.as_ref()
581 && let Some(pred) = predicate
582 {
583 let pred_columns = pred.referenced_columns();
584
585 for index in &table_stats.indices {
586 if self.index_covers_predicate(index, &pred_columns) {
587 let selectivity = self.estimate_selectivity(pred, table_stats);
588 let index_cost = self.estimate_index_cost(index, row_count, selectivity);
589
590 if index_cost < best_index_cost {
591 best_index_cost = index_cost;
592 best_index = Some(index);
593 }
594 }
595 }
596 }
597
598 if best_index_cost < scan_cost {
600 let index = best_index.unwrap();
601 let selectivity = predicate
602 .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
603 .unwrap_or(1.0);
604
605 PhysicalPlan::IndexSeek {
606 table: table.to_string(),
607 index: index.name.clone(),
608 columns: columns.to_vec(),
609 key_range: predicate
610 .map(|p| Self::derive_key_range(p))
611 .unwrap_or_else(KeyRange::all),
612 predicate: predicate.map(|p| Box::new(p.clone())),
613 estimated_rows: (row_count as f64 * selectivity).max(1.0) as u64,
614 estimated_cost: best_index_cost,
615 }
616 } else {
617 PhysicalPlan::TableScan {
618 table: table.to_string(),
619 columns: columns.to_vec(),
620 predicate: predicate.map(|p| Box::new(p.clone())),
621 estimated_rows: row_count,
622 estimated_cost: scan_cost,
623 }
624 }
625 }
626
627 fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
629 if let Some(first_col) = index.columns.first() {
631 pred_columns.contains(first_col)
632 } else {
633 false
634 }
635 }
636
637 fn estimate_scan_cost(
642 &self,
643 row_count: u64,
644 size_bytes: u64,
645 _predicate: Option<&Predicate>,
646 ) -> f64 {
647 let blocks = (size_bytes as f64 / self.config.block_size as f64).ceil().max(1.0) as u64;
648
649 let io_cost = blocks as f64 * self.config.c_seq;
651
652 let cpu_cost = row_count as f64 * self.config.c_filter;
654
655 io_cost + cpu_cost
656 }
657
658 fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
662 let tree_cost = index.height as f64 * self.config.c_random;
664
665 let matching_rows = (total_rows as f64 * selectivity) as u64;
667 let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
668 let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
669
670 let fetch_cost = if index.is_primary {
672 0.0 } else {
674 matching_rows.min(1000) as f64 * self.config.c_random * 0.1 };
676
677 tree_cost + leaf_cost + fetch_cost
678 }
679
680 #[allow(clippy::only_used_in_recursion)]
682 fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
683 match predicate {
684 Predicate::Eq { column, value } => {
685 if let Some(col_stats) = stats.column_stats.get(column) {
686 for (mcv_val, freq) in &col_stats.mcv {
688 if mcv_val == value {
689 return *freq;
690 }
691 }
692 1.0 / col_stats.distinct_count.max(1) as f64
694 } else {
695 0.1 }
697 }
698 Predicate::Ne { .. } => 0.9, Predicate::Lt { column, value }
700 | Predicate::Le { column, value }
701 | Predicate::Gt { column, value }
702 | Predicate::Ge { column, value } => {
703 if let Some(col_stats) = stats.column_stats.get(column) {
704 if let Some(ref hist) = col_stats.histogram {
705 let val: f64 = value.parse().unwrap_or(0.0);
706 match predicate {
707 Predicate::Lt { .. } | Predicate::Le { .. } => {
708 hist.estimate_range_selectivity(None, Some(val))
709 }
710 _ => hist.estimate_range_selectivity(Some(val), None),
711 }
712 } else {
713 0.25 }
715 } else {
716 0.25
717 }
718 }
719 Predicate::Between { column, min, max } => {
720 if let Some(col_stats) = stats.column_stats.get(column) {
721 if let Some(ref hist) = col_stats.histogram {
722 let min_val: f64 = min.parse().unwrap_or(0.0);
723 let max_val: f64 = max.parse().unwrap_or(f64::MAX);
724 hist.estimate_range_selectivity(Some(min_val), Some(max_val))
725 } else {
726 0.2
727 }
728 } else {
729 0.2
730 }
731 }
732 Predicate::In { column, values } => {
733 if let Some(col_stats) = stats.column_stats.get(column) {
734 (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
735 } else {
736 (values.len() as f64 * 0.1).min(0.5)
737 }
738 }
739 Predicate::Like { .. } => 0.15, Predicate::IsNull { column } => {
741 if let Some(col_stats) = stats.column_stats.get(column) {
742 col_stats.null_count as f64 / stats.row_count.max(1) as f64
743 } else {
744 0.01
745 }
746 }
747 Predicate::IsNotNull { column } => {
748 if let Some(col_stats) = stats.column_stats.get(column) {
749 1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
750 } else {
751 0.99
752 }
753 }
754 Predicate::And(left, right) => {
755 self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
757 }
758 Predicate::Or(left, right) => {
759 let s1 = self.estimate_selectivity(left, stats);
760 let s2 = self.estimate_selectivity(right, stats);
761 (s1 + s2 - s1 * s2).min(1.0)
763 }
764 Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
765 }
766 }
767
768 fn derive_key_range(predicate: &Predicate) -> KeyRange {
770 match predicate {
771 Predicate::Eq { value, .. } => KeyRange::point(value.as_bytes().to_vec()),
772 Predicate::Lt { value, .. } | Predicate::Le { value, .. } => {
773 KeyRange::range(None, Some(value.as_bytes().to_vec()), matches!(predicate, Predicate::Le { .. }))
774 }
775 Predicate::Gt { value, .. } | Predicate::Ge { value, .. } => {
776 KeyRange::range(Some(value.as_bytes().to_vec()), None, matches!(predicate, Predicate::Ge { .. }))
777 }
778 Predicate::Between { min, max, .. } => KeyRange {
779 start: Some(min.as_bytes().to_vec()),
780 end: Some(max.as_bytes().to_vec()),
781 start_inclusive: true,
782 end_inclusive: true,
783 },
784 Predicate::And(left, _) => Self::derive_key_range(left),
785 _ => KeyRange::all(),
786 }
787 }
788
789 fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
793 match plan {
794 PhysicalPlan::TableScan {
795 ref table,
796 predicate,
797 estimated_rows,
798 estimated_cost,
799 columns: ref all_columns,
800 ..
801 } => {
802 let col_ratio = if all_columns.is_empty() || columns.is_empty() {
804 1.0
805 } else {
806 (columns.len() as f64 / all_columns.len().max(1) as f64).clamp(0.1, 1.0)
807 };
808 PhysicalPlan::TableScan {
809 table: table.clone(),
810 columns,
811 predicate,
812 estimated_rows,
813 estimated_cost: estimated_cost * col_ratio,
814 }
815 }
816 PhysicalPlan::IndexSeek {
817 table,
818 index,
819 key_range,
820 predicate,
821 estimated_rows,
822 estimated_cost,
823 ..
824 } => {
825 PhysicalPlan::IndexSeek {
826 table,
827 index,
828 columns, key_range,
830 predicate,
831 estimated_rows,
832 estimated_cost,
833 }
834 }
835 other => PhysicalPlan::Project {
836 input: Box::new(other),
837 columns,
838 estimated_cost: 0.0,
839 },
840 }
841 }
842
843 fn add_sort(
845 &self,
846 plan: PhysicalPlan,
847 order_by: Vec<(String, SortDirection)>,
848 _stats: &Option<TableStats>,
849 ) -> PhysicalPlan {
850 let estimated_rows = self.get_plan_rows(&plan);
851 let sort_cost = if estimated_rows > 0 {
852 estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
853 } else {
854 0.0
855 };
856
857 PhysicalPlan::Sort {
858 input: Box::new(plan),
859 order_by,
860 estimated_cost: sort_cost,
861 }
862 }
863
864 #[allow(clippy::only_used_in_recursion)]
866 fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
867 match plan {
868 PhysicalPlan::TableScan { estimated_rows, .. }
869 | PhysicalPlan::IndexSeek { estimated_rows, .. }
870 | PhysicalPlan::Filter { estimated_rows, .. }
871 | PhysicalPlan::Aggregate { estimated_rows, .. }
872 | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
873 | PhysicalPlan::HashJoin { estimated_rows, .. }
874 | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
875 PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
876 self.get_plan_rows(input)
877 }
878 PhysicalPlan::Limit { limit, .. } => *limit,
879 }
880 }
881
882 #[allow(clippy::only_used_in_recursion)]
884 pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
885 match plan {
886 PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
887 PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
888 PhysicalPlan::Filter {
889 estimated_cost,
890 input,
891 ..
892 } => *estimated_cost + self.get_plan_cost(input),
893 PhysicalPlan::Project {
894 estimated_cost,
895 input,
896 ..
897 } => *estimated_cost + self.get_plan_cost(input),
898 PhysicalPlan::Sort {
899 estimated_cost,
900 input,
901 ..
902 } => *estimated_cost + self.get_plan_cost(input),
903 PhysicalPlan::Limit {
904 estimated_cost,
905 input,
906 ..
907 } => *estimated_cost + self.get_plan_cost(input),
908 PhysicalPlan::NestedLoopJoin {
909 estimated_cost,
910 outer,
911 inner,
912 ..
913 } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
914 PhysicalPlan::HashJoin {
915 estimated_cost,
916 build,
917 probe,
918 ..
919 } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
920 PhysicalPlan::MergeJoin {
921 estimated_cost,
922 left,
923 right,
924 ..
925 } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
926 PhysicalPlan::Aggregate {
927 estimated_cost,
928 input,
929 ..
930 } => *estimated_cost + self.get_plan_cost(input),
931 }
932 }
933
934 pub fn explain(&self, plan: &PhysicalPlan) -> String {
936 self.explain_impl(plan, 0)
937 }
938
939 fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
940 let prefix = " ".repeat(indent);
941 let cost = self.get_plan_cost(plan);
942
943 match plan {
944 PhysicalPlan::TableScan {
945 table,
946 columns,
947 estimated_rows,
948 ..
949 } => {
950 format!(
951 "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
952 prefix, table, columns, estimated_rows, cost
953 )
954 }
955 PhysicalPlan::IndexSeek {
956 table,
957 index,
958 columns,
959 estimated_rows,
960 ..
961 } => {
962 format!(
963 "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
964 prefix, table, index, columns, estimated_rows, cost
965 )
966 }
967 PhysicalPlan::Filter {
968 input,
969 estimated_rows,
970 ..
971 } => {
972 format!(
973 "{}Filter [rows={}, cost={:.2}ms]\n{}",
974 prefix,
975 estimated_rows,
976 cost,
977 self.explain_impl(input, indent + 1)
978 )
979 }
980 PhysicalPlan::Project { input, columns, .. } => {
981 format!(
982 "{}Project [columns={:?}, cost={:.2}ms]\n{}",
983 prefix,
984 columns,
985 cost,
986 self.explain_impl(input, indent + 1)
987 )
988 }
989 PhysicalPlan::Sort {
990 input, order_by, ..
991 } => {
992 let order: Vec<_> = order_by
993 .iter()
994 .map(|(c, d)| format!("{} {:?}", c, d))
995 .collect();
996 format!(
997 "{}Sort [order={:?}, cost={:.2}ms]\n{}",
998 prefix,
999 order,
1000 cost,
1001 self.explain_impl(input, indent + 1)
1002 )
1003 }
1004 PhysicalPlan::Limit {
1005 input,
1006 limit,
1007 offset,
1008 ..
1009 } => {
1010 format!(
1011 "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
1012 prefix,
1013 limit,
1014 offset,
1015 cost,
1016 self.explain_impl(input, indent + 1)
1017 )
1018 }
1019 PhysicalPlan::HashJoin {
1020 build,
1021 probe,
1022 join_type,
1023 estimated_rows,
1024 ..
1025 } => {
1026 format!(
1027 "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1028 prefix,
1029 join_type,
1030 estimated_rows,
1031 cost,
1032 self.explain_impl(build, indent + 1),
1033 self.explain_impl(probe, indent + 1)
1034 )
1035 }
1036 PhysicalPlan::MergeJoin {
1037 left,
1038 right,
1039 join_type,
1040 estimated_rows,
1041 ..
1042 } => {
1043 format!(
1044 "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1045 prefix,
1046 join_type,
1047 estimated_rows,
1048 cost,
1049 self.explain_impl(left, indent + 1),
1050 self.explain_impl(right, indent + 1)
1051 )
1052 }
1053 PhysicalPlan::NestedLoopJoin {
1054 outer,
1055 inner,
1056 join_type,
1057 estimated_rows,
1058 ..
1059 } => {
1060 format!(
1061 "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1062 prefix,
1063 join_type,
1064 estimated_rows,
1065 cost,
1066 self.explain_impl(outer, indent + 1),
1067 self.explain_impl(inner, indent + 1)
1068 )
1069 }
1070 PhysicalPlan::Aggregate {
1071 input,
1072 group_by,
1073 aggregates,
1074 estimated_rows,
1075 ..
1076 } => {
1077 let aggs: Vec<_> = aggregates
1078 .iter()
1079 .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
1080 .collect();
1081 format!(
1082 "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
1083 prefix,
1084 group_by,
1085 aggs,
1086 estimated_rows,
1087 cost,
1088 self.explain_impl(input, indent + 1)
1089 )
1090 }
1091 }
1092 }
1093}
1094
1095impl CostBasedOptimizer {
1100 pub fn evict_stale_plans(&self) {
1102 let now = Self::now_us();
1103 self.plan_cache
1104 .write()
1105 .retain(|_, (_, ts)| now.saturating_sub(*ts) < self.plan_cache_ttl_us);
1106 }
1107
1108 pub fn invalidate_plan_cache(&self) {
1110 self.plan_cache.write().clear();
1111 }
1112
1113 pub fn collect_stats(
1118 &self,
1119 table_name: &str,
1120 row_count: u64,
1121 size_bytes: u64,
1122 column_values: HashMap<String, Vec<String>>,
1123 indices: Vec<IndexStats>,
1124 ) {
1125 let mut column_stats = HashMap::new();
1126 for (col_name, values) in &column_values {
1127 let distinct: HashSet<&String> = values.iter().collect();
1128 let null_count = values.iter().filter(|v| v.is_empty()).count() as u64;
1129 let avg_length = if values.is_empty() {
1130 0.0
1131 } else {
1132 values.iter().map(|v| v.len()).sum::<usize>() as f64 / values.len() as f64
1133 };
1134
1135 let is_numeric = values.iter().take(10).all(|v| v.parse::<f64>().is_ok());
1137 let histogram = if is_numeric && values.len() >= 10 {
1138 let mut nums: Vec<f64> = values.iter().filter_map(|v| v.parse().ok()).collect();
1139 nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1140 let bucket_count = 10.min(nums.len());
1141 let bucket_size = nums.len() / bucket_count;
1142 let mut boundaries = Vec::new();
1143 let mut counts = Vec::new();
1144 for i in 0..bucket_count {
1145 let end = if i == bucket_count - 1 {
1146 nums.len()
1147 } else {
1148 (i + 1) * bucket_size
1149 };
1150 let start = i * bucket_size;
1151 boundaries.push(nums[end - 1]);
1152 counts.push((end - start) as u64);
1153 }
1154 Some(Histogram {
1155 boundaries,
1156 counts,
1157 total_rows: nums.len() as u64,
1158 })
1159 } else {
1160 None
1161 };
1162
1163 let mut freq_map: HashMap<&String, usize> = HashMap::new();
1165 for v in values {
1166 *freq_map.entry(v).or_insert(0) += 1;
1167 }
1168 let total = values.len() as f64;
1169 let mut mcv: Vec<(String, f64)> = freq_map
1170 .iter()
1171 .map(|(k, &v)| ((*k).clone(), v as f64 / total))
1172 .collect();
1173 mcv.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1174 mcv.truncate(5);
1175
1176 column_stats.insert(
1177 col_name.clone(),
1178 ColumnStats {
1179 name: col_name.clone(),
1180 distinct_count: distinct.len() as u64,
1181 null_count,
1182 min_value: values.iter().min().cloned(),
1183 max_value: values.iter().max().cloned(),
1184 avg_length,
1185 mcv,
1186 histogram,
1187 },
1188 );
1189 }
1190
1191 self.update_stats(TableStats {
1192 name: table_name.to_string(),
1193 row_count,
1194 size_bytes,
1195 column_stats,
1196 indices,
1197 last_updated: Self::now_us(),
1198 });
1199
1200 self.invalidate_plan_cache();
1202 }
1203
1204 pub fn stats_age_us(&self, table: &str) -> Option<u64> {
1206 self.stats_cache.read().get(table).map(|s| {
1207 Self::now_us().saturating_sub(s.last_updated)
1208 })
1209 }
1210
1211 fn now_us() -> u64 {
1212 SystemTime::now()
1213 .duration_since(UNIX_EPOCH)
1214 .unwrap_or_default()
1215 .as_micros() as u64
1216 }
1217}
1218
1219pub struct JoinOrderOptimizer {
1225 stats: HashMap<String, TableStats>,
1227 config: CostModelConfig,
1229}
1230
1231impl JoinOrderOptimizer {
1232 pub fn new(config: CostModelConfig) -> Self {
1233 Self {
1234 stats: HashMap::new(),
1235 config,
1236 }
1237 }
1238
1239 pub fn add_stats(&mut self, stats: TableStats) {
1241 self.stats.insert(stats.name.clone(), stats);
1242 }
1243
1244 pub fn find_optimal_order(
1249 &self,
1250 tables: &[String],
1251 join_conditions: &[(String, String, String, String)], ) -> Vec<(String, String)> {
1253 let n = tables.len();
1254 if n <= 1 {
1255 return vec![];
1256 }
1257
1258 let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1260
1261 for (i, _table) in tables.iter().enumerate() {
1263 let mask = 1u32 << i;
1264 dp.insert(mask, (0.0, vec![]));
1265 }
1266
1267 for size in 2..=n {
1269 for mask in 0..(1u32 << n) {
1270 if mask.count_ones() != size as u32 {
1271 continue;
1272 }
1273
1274 let mut best_cost = f64::MAX;
1275 let mut best_order = vec![];
1276
1277 for sub in 1..mask {
1279 if sub & mask != sub || sub == 0 {
1280 continue;
1281 }
1282 let other = mask ^ sub;
1283 if other == 0 {
1284 continue;
1285 }
1286
1287 if !self.has_join_condition(tables, sub, other, join_conditions) {
1289 continue;
1290 }
1291
1292 if let (Some((cost1, order1)), Some((cost2, order2))) =
1293 (dp.get(&sub), dp.get(&other))
1294 {
1295 let join_cost = self.estimate_join_cost(tables, sub, other);
1296 let total_cost = cost1 + cost2 + join_cost;
1297
1298 if total_cost < best_cost {
1299 best_cost = total_cost;
1300 best_order = order1.clone();
1301 best_order.extend(order2.clone());
1302
1303 let (t1, t2) =
1305 self.get_join_tables(tables, sub, other, join_conditions);
1306 if let Some((t1, t2)) = Some((t1, t2)) {
1307 best_order.push((t1, t2));
1308 }
1309 }
1310 }
1311 }
1312
1313 if best_cost < f64::MAX {
1314 dp.insert(mask, (best_cost, best_order));
1315 }
1316 }
1317 }
1318
1319 let full_mask = (1u32 << n) - 1;
1320 dp.get(&full_mask)
1321 .map(|(_, order)| order.clone())
1322 .unwrap_or_default()
1323 }
1324
1325 fn has_join_condition(
1326 &self,
1327 tables: &[String],
1328 mask1: u32,
1329 mask2: u32,
1330 conditions: &[(String, String, String, String)],
1331 ) -> bool {
1332 for (t1, _, t2, _) in conditions {
1333 let idx1 = tables.iter().position(|t| t == t1);
1334 let idx2 = tables.iter().position(|t| t == t2);
1335
1336 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1337 let in_mask1 = (mask1 >> i1) & 1 == 1;
1338 let in_mask2 = (mask2 >> i2) & 1 == 1;
1339
1340 if in_mask1 && in_mask2 {
1341 return true;
1342 }
1343 }
1344 }
1345 false
1346 }
1347
1348 fn get_join_tables(
1349 &self,
1350 tables: &[String],
1351 mask1: u32,
1352 mask2: u32,
1353 conditions: &[(String, String, String, String)],
1354 ) -> (String, String) {
1355 for (t1, _, t2, _) in conditions {
1356 let idx1 = tables.iter().position(|t| t == t1);
1357 let idx2 = tables.iter().position(|t| t == t2);
1358
1359 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1360 let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1361 let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1362
1363 if t1_in_mask1 && t2_in_mask2 {
1364 return (t1.clone(), t2.clone());
1365 }
1366 }
1367 }
1368 (String::new(), String::new())
1369 }
1370
1371 fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1372 let rows1 = self.estimate_rows_for_mask(tables, mask1);
1373 let rows2 = self.estimate_rows_for_mask(tables, mask2);
1374
1375 let build_cost = rows1 as f64 * self.config.c_filter;
1378 let probe_cost = rows2 as f64 * self.config.c_filter;
1379
1380 build_cost + probe_cost
1381 }
1382
1383 fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1384 let mut total = 1u64;
1385
1386 for (i, table) in tables.iter().enumerate() {
1387 if (mask >> i) & 1 == 1 {
1388 let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1389 total = total.saturating_mul(rows);
1390 }
1391 }
1392
1393 let num_tables = mask.count_ones();
1395 if num_tables > 1 {
1396 total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1397 }
1398
1399 total.max(1)
1400 }
1401}
1402
1403#[cfg(test)]
1408mod tests {
1409 use super::*;
1410
1411 fn create_test_stats() -> TableStats {
1412 let mut column_stats = HashMap::new();
1413 column_stats.insert(
1414 "id".to_string(),
1415 ColumnStats {
1416 name: "id".to_string(),
1417 distinct_count: 100000,
1418 null_count: 0,
1419 min_value: Some("1".to_string()),
1420 max_value: Some("100000".to_string()),
1421 avg_length: 8.0,
1422 mcv: vec![],
1423 histogram: None,
1424 },
1425 );
1426 column_stats.insert(
1427 "score".to_string(),
1428 ColumnStats {
1429 name: "score".to_string(),
1430 distinct_count: 100,
1431 null_count: 1000,
1432 min_value: Some("0".to_string()),
1433 max_value: Some("100".to_string()),
1434 avg_length: 8.0,
1435 mcv: vec![("50".to_string(), 0.05)],
1436 histogram: Some(Histogram {
1437 boundaries: vec![25.0, 50.0, 75.0, 100.0],
1438 counts: vec![25000, 25000, 25000, 25000],
1439 total_rows: 100000,
1440 }),
1441 },
1442 );
1443
1444 TableStats {
1445 name: "users".to_string(),
1446 row_count: 100000,
1447 size_bytes: 10_000_000, column_stats,
1449 indices: vec![
1450 IndexStats {
1451 name: "pk_users".to_string(),
1452 columns: vec!["id".to_string()],
1453 is_primary: true,
1454 is_unique: true,
1455 index_type: IndexType::BTree,
1456 leaf_pages: 1000,
1457 height: 3,
1458 avg_leaf_density: 100.0,
1459 },
1460 IndexStats {
1461 name: "idx_score".to_string(),
1462 columns: vec!["score".to_string()],
1463 is_primary: false,
1464 is_unique: false,
1465 index_type: IndexType::BTree,
1466 leaf_pages: 500,
1467 height: 2,
1468 avg_leaf_density: 200.0,
1469 },
1470 ],
1471 last_updated: 0,
1472 }
1473 }
1474
1475 #[test]
1476 fn test_selectivity_estimation() {
1477 let config = CostModelConfig::default();
1478 let optimizer = CostBasedOptimizer::new(config);
1479
1480 let stats = create_test_stats();
1481 optimizer.update_stats(stats.clone());
1482
1483 let pred = Predicate::Eq {
1485 column: "id".to_string(),
1486 value: "12345".to_string(),
1487 };
1488 let sel = optimizer.estimate_selectivity(&pred, &stats);
1489 assert!(sel < 0.001); let pred = Predicate::Gt {
1495 column: "score".to_string(),
1496 value: "75".to_string(),
1497 };
1498 let sel = optimizer.estimate_selectivity(&pred, &stats);
1499 assert!(sel > 0.4 && sel < 0.6); }
1501
1502 #[test]
1503 fn test_access_path_selection() {
1504 let config = CostModelConfig::default();
1505 let optimizer = CostBasedOptimizer::new(config);
1506
1507 let stats = create_test_stats();
1508 optimizer.update_stats(stats);
1509
1510 let pred = Predicate::Eq {
1512 column: "id".to_string(),
1513 value: "12345".to_string(),
1514 };
1515 let plan = optimizer.optimize(
1516 "users",
1517 vec!["id".to_string(), "score".to_string()],
1518 Some(pred),
1519 vec![],
1520 None,
1521 );
1522
1523 match plan {
1524 PhysicalPlan::IndexSeek { index, .. } => {
1525 assert_eq!(index, "pk_users");
1526 }
1527 _ => panic!("Expected IndexSeek for equality on primary key"),
1528 }
1529 }
1530
1531 #[test]
1532 fn test_token_budget_limit() {
1533 let config = CostModelConfig::default();
1534 let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);
1535
1536 let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1539
1540 match plan {
1541 PhysicalPlan::Limit { limit, .. } => {
1542 assert!(limit <= 80);
1543 }
1544 _ => panic!("Expected Limit to be injected"),
1545 }
1546 }
1547
1548 #[test]
1549 fn test_explain_output() {
1550 let config = CostModelConfig::default();
1551 let optimizer = CostBasedOptimizer::new(config);
1552
1553 let stats = create_test_stats();
1554 optimizer.update_stats(stats);
1555
1556 let plan = optimizer.optimize(
1557 "users",
1558 vec!["id".to_string(), "score".to_string()],
1559 Some(Predicate::Gt {
1560 column: "score".to_string(),
1561 value: "80".to_string(),
1562 }),
1563 vec![("score".to_string(), SortDirection::Descending)],
1564 Some(10),
1565 );
1566
1567 let explain = optimizer.explain(&plan);
1568 assert!(explain.contains("Limit"));
1569 assert!(explain.contains("Sort"));
1570 }
1571
1572 #[test]
1577 fn test_token_budget_underflow_safety() {
1578 let config = CostModelConfig::default();
1580 let optimizer = CostBasedOptimizer::new(config).with_token_budget(10, 25.0);
1581
1582 let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1583 match plan {
1584 PhysicalPlan::Limit { limit, .. } => {
1585 assert!(limit >= 1, "Must return at least 1 row");
1586 }
1587 _ => panic!("Expected Limit"),
1588 }
1589 }
1590
1591 #[test]
1592 fn test_index_seek_derives_key_range() {
1593 let config = CostModelConfig::default();
1594 let optimizer = CostBasedOptimizer::new(config);
1595 optimizer.update_stats(create_test_stats());
1596
1597 let plan = optimizer.optimize(
1598 "users",
1599 vec!["id".to_string()],
1600 Some(Predicate::Eq {
1601 column: "id".to_string(),
1602 value: "42".to_string(),
1603 }),
1604 vec![],
1605 None,
1606 );
1607
1608 match plan {
1609 PhysicalPlan::IndexSeek { key_range, .. } => {
1610 assert!(key_range.start.is_some(), "KeyRange must derive from Eq predicate");
1611 assert_eq!(key_range.start, key_range.end, "Eq predicate → point key range");
1612 }
1613 _ => panic!("Expected IndexSeek"),
1614 }
1615 }
1616
1617 #[test]
1618 fn test_range_predicate_key_range() {
1619 let config = CostModelConfig::default();
1620 let optimizer = CostBasedOptimizer::new(config);
1621 optimizer.update_stats(create_test_stats());
1622
1623 let plan = optimizer.optimize(
1624 "users",
1625 vec!["score".to_string()],
1626 Some(Predicate::Between {
1627 column: "score".to_string(),
1628 min: "10".to_string(),
1629 max: "90".to_string(),
1630 }),
1631 vec![],
1632 None,
1633 );
1634
1635 match plan {
1636 PhysicalPlan::IndexSeek { key_range, .. } => {
1637 assert!(key_range.start.is_some());
1638 assert!(key_range.end.is_some());
1639 assert!(key_range.start_inclusive);
1640 assert!(key_range.end_inclusive);
1641 }
1642 _ => {} }
1644 }
1645
1646 #[test]
1647 fn test_projection_pushdown_proportional_reduction() {
1648 let config = CostModelConfig::default();
1649 let optimizer = CostBasedOptimizer::new(config);
1650 optimizer.update_stats(create_test_stats());
1651
1652 let plan_all = optimizer.optimize(
1654 "users",
1655 vec!["id".to_string(), "score".to_string()],
1656 None,
1657 vec![],
1658 Some(100),
1659 );
1660 let plan_single = optimizer.optimize(
1661 "users",
1662 vec!["id".to_string()],
1663 None,
1664 vec![],
1665 Some(100),
1666 );
1667
1668 let cost_all = optimizer.get_plan_cost(&plan_all);
1669 let cost_single = optimizer.get_plan_cost(&plan_single);
1670 assert!(cost_single <= cost_all, "Projection should reduce cost: {} vs {}", cost_single, cost_all);
1672 }
1673
1674 #[test]
1675 fn test_collect_stats_builds_histogram() {
1676 let config = CostModelConfig::default();
1677 let optimizer = CostBasedOptimizer::new(config);
1678
1679 let mut column_values = HashMap::new();
1680 let scores: Vec<String> = (0..100).map(|i| i.to_string()).collect();
1681 column_values.insert("score".to_string(), scores);
1682
1683 optimizer.collect_stats("test_table", 100, 10000, column_values, vec![]);
1684
1685 let stats = optimizer.get_stats("test_table").unwrap();
1686 assert_eq!(stats.row_count, 100);
1687 let score_stats = stats.column_stats.get("score").unwrap();
1688 assert_eq!(score_stats.distinct_count, 100);
1689 assert!(score_stats.histogram.is_some(), "Numeric column should get histogram");
1690 assert!(!score_stats.mcv.is_empty(), "Should build MCV list");
1691 }
1692
1693 #[test]
1694 fn test_plan_cache_invalidation() {
1695 let config = CostModelConfig::default();
1696 let optimizer = CostBasedOptimizer::new(config);
1697
1698 let mut col = HashMap::new();
1700 col.insert("x".to_string(), vec!["1".to_string()]);
1701 optimizer.collect_stats("t", 1, 100, col.clone(), vec![]);
1702
1703 assert!(optimizer.plan_cache.read().is_empty());
1705 }
1706
1707 #[test]
1708 fn test_stats_age_tracking() {
1709 let config = CostModelConfig::default();
1710 let optimizer = CostBasedOptimizer::new(config);
1711
1712 assert!(optimizer.stats_age_us("unknown").is_none());
1713
1714 let mut col = HashMap::new();
1715 col.insert("x".to_string(), vec!["1".to_string()]);
1716 optimizer.collect_stats("t", 1, 100, col, vec![]);
1717
1718 let age = optimizer.stats_age_us("t").unwrap();
1719 assert!(age < 1_000_000, "Stats should be fresh (< 1 second old)");
1720 }
1721
1722 #[test]
1723 fn test_scan_cost_reads_all_blocks() {
1724 let config = CostModelConfig::default();
1726 let optimizer = CostBasedOptimizer::new(config.clone());
1727 let no_pred = optimizer.estimate_scan_cost(1000, 4096 * 10, None);
1728 let with_pred = optimizer.estimate_scan_cost(
1729 1000,
1730 4096 * 10,
1731 Some(&Predicate::Eq {
1732 column: "x".to_string(),
1733 value: "1".to_string(),
1734 }),
1735 );
1736 assert!(
1739 (no_pred - with_pred).abs() < 0.001,
1740 "Scan cost should not depend on predicate: {} vs {}",
1741 no_pred,
1742 with_pred
1743 );
1744 }
1745
1746 #[test]
1747 fn test_index_wins_over_scan_for_point_lookup() {
1748 let config = CostModelConfig::default();
1749 let optimizer = CostBasedOptimizer::new(config);
1750 optimizer.update_stats(create_test_stats());
1751
1752 let scan_cost = optimizer.estimate_scan_cost(100000, 10_000_000, None);
1753
1754 let pk_index = &create_test_stats().indices[0]; let index_cost = optimizer.estimate_index_cost(pk_index, 100000, 0.00001);
1757
1758 assert!(
1759 index_cost < scan_cost * 0.1,
1760 "Index point lookup ({:.2}) should be <10% of scan cost ({:.2})",
1761 index_cost,
1762 scan_cost
1763 );
1764 }
1765
1766 #[test]
1767 fn test_no_stats_defaults_to_scan() {
1768 let config = CostModelConfig::default();
1769 let optimizer = CostBasedOptimizer::new(config);
1770 let plan = optimizer.optimize(
1772 "unknown_table",
1773 vec!["col1".to_string()],
1774 Some(Predicate::Eq {
1775 column: "col1".to_string(),
1776 value: "x".to_string(),
1777 }),
1778 vec![],
1779 None,
1780 );
1781 match plan {
1783 PhysicalPlan::TableScan { estimated_rows, .. } => {
1784 assert!(estimated_rows > 0, "Default row estimate must be positive");
1785 }
1786 PhysicalPlan::IndexSeek { .. } => {} _ => panic!("Expected TableScan or IndexSeek for unknown table"),
1788 }
1789 }
1790
1791 #[test]
1792 fn test_compound_predicate_selectivity() {
1793 let stats = create_test_stats();
1794 let config = CostModelConfig::default();
1795 let optimizer = CostBasedOptimizer::new(config);
1796
1797 let and_pred = Predicate::And(
1799 Box::new(Predicate::Eq {
1800 column: "id".to_string(),
1801 value: "1".to_string(),
1802 }),
1803 Box::new(Predicate::IsNotNull {
1804 column: "score".to_string(),
1805 }),
1806 );
1807 let sel = optimizer.estimate_selectivity(&and_pred, &stats);
1808 let eq_sel = optimizer.estimate_selectivity(
1809 &Predicate::Eq { column: "id".to_string(), value: "1".to_string() },
1810 &stats,
1811 );
1812 assert!(sel < eq_sel, "AND must be more selective than either child");
1813
1814 let or_pred = Predicate::Or(
1816 Box::new(Predicate::Eq {
1817 column: "id".to_string(),
1818 value: "1".to_string(),
1819 }),
1820 Box::new(Predicate::Eq {
1821 column: "id".to_string(),
1822 value: "2".to_string(),
1823 }),
1824 );
1825 let sel = optimizer.estimate_selectivity(&or_pred, &stats);
1826 assert!(sel > eq_sel, "OR must be less selective than either child");
1827 assert!(sel <= 1.0, "Selectivity must be <= 1.0");
1828 }
1829
1830 #[test]
1831 fn test_join_order_optimizer() {
1832 let mut join_opt = JoinOrderOptimizer::new(CostModelConfig::default());
1833 join_opt.add_stats(TableStats {
1834 name: "orders".to_string(),
1835 row_count: 1000000,
1836 size_bytes: 100_000_000,
1837 column_stats: HashMap::new(),
1838 indices: vec![],
1839 last_updated: 0,
1840 });
1841 join_opt.add_stats(TableStats {
1842 name: "users".to_string(),
1843 row_count: 10000,
1844 size_bytes: 1_000_000,
1845 column_stats: HashMap::new(),
1846 indices: vec![],
1847 last_updated: 0,
1848 });
1849
1850 let order = join_opt.find_optimal_order(
1851 &["orders".to_string(), "users".to_string()],
1852 &[("orders".to_string(), "user_id".to_string(), "users".to_string(), "id".to_string())],
1853 );
1854 assert!(!order.is_empty(), "Should find a join order");
1855 }
1856}