1use std::collections::{BTreeSet, HashMap, HashSet};
14use std::fmt;
15
/// A relation (table or predicate) that takes part in a join, together with
/// the statistics the optimizer uses to order it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Relation {
    /// Name used to reference this relation from [`JoinCondition`]s.
    pub name: String,
    /// Number of columns.
    pub arity: usize,
    /// Estimated row count. The optimizer also uses it as the cost of scanning
    /// this relation.
    pub estimated_rows: u64,
    /// Column indices marked as bound via `with_binding`; counted by
    /// `selectivity`.
    pub bound_columns: BTreeSet<usize>,
}
32
33impl Relation {
34 pub fn new(name: impl Into<String>, arity: usize, estimated_rows: u64) -> Self {
36 Self {
37 name: name.into(),
38 arity,
39 estimated_rows,
40 bound_columns: BTreeSet::new(),
41 }
42 }
43
44 pub fn with_binding(mut self, col: usize) -> Self {
46 self.bound_columns.insert(col);
47 self
48 }
49
50 pub fn selectivity(&self) -> f64 {
52 if self.arity == 0 {
53 return 0.0;
54 }
55 self.bound_columns.len() as f64 / self.arity as f64
56 }
57}
58
/// A join predicate linking column `left_column` of `left_relation` with
/// column `right_column` of `right_relation` (presumably an equality; the
/// optimizer itself only inspects the relation names).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JoinCondition {
    /// Name of the relation on the left side of the predicate.
    pub left_relation: String,
    /// Column index within `left_relation`.
    pub left_column: usize,
    /// Name of the relation on the right side of the predicate.
    pub right_relation: String,
    /// Column index within `right_relation`.
    pub right_column: usize,
}
71
/// A node in a physical join plan tree.
#[derive(Debug, Clone)]
pub enum JoinPlanNode {
    /// Leaf that scans a single relation.
    Scan {
        /// Name of the scanned relation.
        relation: String,
        /// Cost of the scan. `JoinOrderOptimizer` sets this to the relation's
        /// `estimated_rows`, and `estimated_output_rows` also reads it as the
        /// scan's row estimate.
        estimated_cost: u64,
    },
    /// Hash join of two sub-plans. `JoinOrderOptimizer` picks it when the
    /// right input's estimated rows exceed `hash_join_threshold`.
    HashJoin {
        /// Left input sub-plan.
        left: Box<JoinPlanNode>,
        /// Right input sub-plan.
        right: Box<JoinPlanNode>,
        /// Conditions connecting relations of `left` with relations of `right`.
        conditions: Vec<JoinCondition>,
        /// Cost of this join; as computed by the optimizer it includes both
        /// children's costs.
        estimated_cost: u64,
        /// Estimated number of output rows.
        estimated_rows: u64,
    },
    /// Nested-loop join of two sub-plans, used by the optimizer when a hash
    /// join is not chosen.
    NestedLoopJoin {
        /// Left input sub-plan.
        left: Box<JoinPlanNode>,
        /// Right input sub-plan.
        right: Box<JoinPlanNode>,
        /// Conditions connecting relations of `left` with relations of `right`.
        conditions: Vec<JoinCondition>,
        /// Cost of this join; as computed by the optimizer it includes both
        /// children's costs.
        estimated_cost: u64,
        /// Estimated number of output rows.
        estimated_rows: u64,
    },
}
101
102impl JoinPlanNode {
103 pub fn cost(&self) -> u64 {
105 match self {
106 Self::Scan { estimated_cost, .. } => *estimated_cost,
107 Self::HashJoin { estimated_cost, .. } => *estimated_cost,
108 Self::NestedLoopJoin { estimated_cost, .. } => *estimated_cost,
109 }
110 }
111
112 pub fn estimated_output_rows(&self) -> u64 {
114 match self {
115 Self::Scan { estimated_cost, .. } => *estimated_cost, Self::HashJoin { estimated_rows, .. } => *estimated_rows,
117 Self::NestedLoopJoin { estimated_rows, .. } => *estimated_rows,
118 }
119 }
120
121 pub fn depth(&self) -> usize {
123 match self {
124 Self::Scan { .. } => 1,
125 Self::HashJoin { left, right, .. } | Self::NestedLoopJoin { left, right, .. } => {
126 1 + left.depth().max(right.depth())
127 }
128 }
129 }
130
131 pub fn relations_involved(&self) -> Vec<String> {
133 let mut out = Vec::new();
134 self.collect_relations(&mut out);
135 out
136 }
137
138 fn collect_relations(&self, out: &mut Vec<String>) {
139 match self {
140 Self::Scan { relation, .. } => out.push(relation.clone()),
141 Self::HashJoin { left, right, .. } | Self::NestedLoopJoin { left, right, .. } => {
142 left.collect_relations(out);
143 right.collect_relations(out);
144 }
145 }
146 }
147
148 fn format_tree_inner(&self, indent: usize, buf: &mut String) {
150 let pad = " ".repeat(indent);
151 match self {
152 Self::Scan {
153 relation,
154 estimated_cost,
155 } => {
156 buf.push_str(&format!("{pad}Scan({relation}, cost={estimated_cost})\n"));
157 }
158 Self::HashJoin {
159 left,
160 right,
161 estimated_cost,
162 estimated_rows,
163 ..
164 } => {
165 buf.push_str(&format!(
166 "{pad}HashJoin(cost={estimated_cost}, rows={estimated_rows})\n"
167 ));
168 left.format_tree_inner(indent + 2, buf);
169 right.format_tree_inner(indent + 2, buf);
170 }
171 Self::NestedLoopJoin {
172 left,
173 right,
174 estimated_cost,
175 estimated_rows,
176 ..
177 } => {
178 buf.push_str(&format!(
179 "{pad}NestedLoopJoin(cost={estimated_cost}, rows={estimated_rows})\n"
180 ));
181 left.format_tree_inner(indent + 2, buf);
182 right.format_tree_inner(indent + 2, buf);
183 }
184 }
185 }
186
187 fn format_dot_inner(&self, counter: &mut usize, buf: &mut String) -> usize {
189 let id = *counter;
190 *counter += 1;
191 match self {
192 Self::Scan {
193 relation,
194 estimated_cost,
195 } => {
196 buf.push_str(&format!(
197 " n{id} [label=\"Scan({relation})\\ncost={estimated_cost}\"];\n"
198 ));
199 }
200 Self::HashJoin {
201 left,
202 right,
203 estimated_cost,
204 estimated_rows,
205 ..
206 } => {
207 buf.push_str(&format!(
208 " n{id} [label=\"HashJoin\\ncost={estimated_cost} rows={estimated_rows}\"];\n"
209 ));
210 let lid = left.format_dot_inner(counter, buf);
211 let rid = right.format_dot_inner(counter, buf);
212 buf.push_str(&format!(" n{id} -> n{lid};\n"));
213 buf.push_str(&format!(" n{id} -> n{rid};\n"));
214 }
215 Self::NestedLoopJoin {
216 left,
217 right,
218 estimated_cost,
219 estimated_rows,
220 ..
221 } => {
222 buf.push_str(&format!(
223 " n{id} [label=\"NLJoin\\ncost={estimated_cost} rows={estimated_rows}\"];\n"
224 ));
225 let lid = left.format_dot_inner(counter, buf);
226 let rid = right.format_dot_inner(counter, buf);
227 buf.push_str(&format!(" n{id} -> n{lid};\n"));
228 buf.push_str(&format!(" n{id} -> n{rid};\n"));
229 }
230 }
231 id
232 }
233}
234
/// Summary statistics for a finished [`JoinPlan`], as filled in by
/// `JoinOrderOptimizer::optimize`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JoinStats {
    /// Number of scan leaves in the plan.
    pub relations_scanned: usize,
    /// Number of join nodes (one fewer than the number of scans).
    pub joins_performed: usize,
    /// Estimated cost of the plan's root node.
    pub total_estimated_cost: u64,
    /// Estimated number of rows produced by the root node.
    pub total_estimated_rows: u64,
    /// Height of the plan tree (a single scan has depth 1).
    pub plan_depth: usize,
}
248
/// Result of join-order optimization: the chosen plan tree plus summary
/// statistics about it.
#[derive(Debug, Clone)]
pub struct JoinPlan {
    /// Root of the plan tree.
    pub root: JoinPlanNode,
    /// Statistics computed from `root` by the optimizer.
    pub stats: JoinStats,
}
259
260impl JoinPlan {
261 pub fn format_tree(&self) -> String {
263 let mut buf = String::new();
264 self.root.format_tree_inner(0, &mut buf);
265 buf
266 }
267
268 pub fn format_dot(&self) -> String {
270 let mut buf = String::from("digraph JoinPlan {\n");
271 let mut counter = 0usize;
272 self.root.format_dot_inner(&mut counter, &mut buf);
273 buf.push_str("}\n");
274 buf
275 }
276
277 pub fn total_cost(&self) -> u64 {
279 self.root.cost()
280 }
281}
282
/// Tuning knobs for [`JoinOrderOptimizer`].
#[derive(Debug, Clone, PartialEq)]
pub struct JoinOptimizerConfig {
    /// Largest number of relations planned with exhaustive dynamic
    /// programming; larger inputs fall back to the greedy heuristic.
    pub max_relations: usize,
    /// A join becomes a hash join when its right input's estimated rows exceed
    /// this value; otherwise it is a nested-loop join.
    pub hash_join_threshold: u64,
    /// Floor on the per-condition selectivity used by `estimate_selectivity`.
    pub default_selectivity: f64,
    /// For hash joins, place the input with fewer estimated rows on the left.
    pub prefer_small_left: bool,
}
299
300impl Default for JoinOptimizerConfig {
301 fn default() -> Self {
302 Self {
303 max_relations: 10,
304 hash_join_threshold: 100,
305 default_selectivity: 0.1,
306 prefer_small_left: true,
307 }
308 }
309}
310
/// Errors produced while computing a join order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinOrderError {
    /// The optimizer was given an empty relation list.
    NoRelations,
    /// No plan covering every relation could be built; the payload explains why.
    DisconnectedGraph(String),
    /// More relations than exhaustive search allows. `optimize` in this module
    /// falls back to greedy ordering instead of returning this.
    TooManyRelations {
        /// Number of relations supplied.
        count: usize,
        /// Configured maximum.
        max: usize,
    },
    /// A join condition is invalid (e.g. it names a relation that was not
    /// supplied); the payload describes the problem.
    InvalidCondition(String),
}
327
328impl fmt::Display for JoinOrderError {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 match self {
331 Self::NoRelations => write!(f, "no relations provided for join ordering"),
332 Self::DisconnectedGraph(msg) => write!(f, "disconnected join graph: {msg}"),
333 Self::TooManyRelations { count, max } => {
334 write!(
335 f,
336 "too many relations ({count}) for exhaustive search (max {max})"
337 )
338 }
339 Self::InvalidCondition(msg) => write!(f, "invalid join condition: {msg}"),
340 }
341 }
342}
343
/// Marker impl so `JoinOrderError` can be used as a `std::error::Error`
/// (for example boxed as `Box<dyn std::error::Error>`); its message comes from
/// the `Display` impl and it has no underlying source.
impl std::error::Error for JoinOrderError {}
345
/// Returns every `k`-element subset of `{0, 1, ..., n - 1}` in lexicographic
/// order.
///
/// Returns a single empty set when `k == 0` and no sets when `k > n`.
fn subsets_of_size(n: usize, k: usize) -> Vec<BTreeSet<usize>> {
    let mut combos = Vec::new();
    if k > n {
        return combos;
    }
    // Current combination as strictly increasing indices.
    let mut picked: Vec<usize> = (0..k).collect();
    loop {
        combos.push(picked.iter().copied().collect());
        // Rightmost position that can still be incremented; position `p` may
        // reach at most `p + n - k`.
        let Some(pos) = (0..k).rev().find(|&p| picked[p] != p + n - k) else {
            return combos;
        };
        picked[pos] += 1;
        for p in pos + 1..k {
            picked[p] = picked[p - 1] + 1;
        }
    }
}
379
/// Estimates the selectivity of joining a `left_rows`-row input with a
/// `right_rows`-row input under `num_conditions` join conditions.
///
/// With no conditions the join is a cross product, so `1.0` is returned.
/// Otherwise each condition contributes a factor of
/// `1 / max(left_rows, right_rows, 1)`, floored at `default_selectivity`. The
/// factors are multiplied together and the product is clamped to
/// `[f64::MIN_POSITIVE, 1.0]`.
pub fn estimate_selectivity(
    left_rows: u64,
    right_rows: u64,
    num_conditions: usize,
    default_selectivity: f64,
) -> f64 {
    if num_conditions == 0 {
        return 1.0;
    }
    let larger_side = std::cmp::max(std::cmp::max(left_rows, right_rows), 1);
    let per_condition = f64::max(1.0 / larger_side as f64, default_selectivity);
    per_condition
        .powi(num_conditions as i32)
        .clamp(f64::MIN_POSITIVE, 1.0)
}
399
400pub struct JoinOrderOptimizer {
406 config: JoinOptimizerConfig,
407}
408
409impl JoinOrderOptimizer {
410 pub fn new(config: JoinOptimizerConfig) -> Self {
412 Self { config }
413 }
414
415 pub fn with_default() -> Self {
417 Self::new(JoinOptimizerConfig::default())
418 }
419
420 pub fn optimize(
425 &self,
426 relations: &[Relation],
427 conditions: &[JoinCondition],
428 ) -> Result<JoinPlan, JoinOrderError> {
429 if relations.is_empty() {
430 return Err(JoinOrderError::NoRelations);
431 }
432
433 let known: HashSet<&str> = relations.iter().map(|r| r.name.as_str()).collect();
435 for c in conditions {
436 if !known.contains(c.left_relation.as_str()) {
437 return Err(JoinOrderError::InvalidCondition(format!(
438 "unknown relation '{}'",
439 c.left_relation
440 )));
441 }
442 if !known.contains(c.right_relation.as_str()) {
443 return Err(JoinOrderError::InvalidCondition(format!(
444 "unknown relation '{}'",
445 c.right_relation
446 )));
447 }
448 }
449
450 let root = if relations.len() > self.config.max_relations {
451 self.greedy_order(relations, conditions)?
452 } else {
453 self.dp_order(relations, conditions)?
454 };
455
456 let rels = root.relations_involved();
457 let joins = if rels.len() > 1 { rels.len() - 1 } else { 0 };
458 let stats = JoinStats {
459 relations_scanned: rels.len(),
460 joins_performed: joins,
461 total_estimated_cost: root.cost(),
462 total_estimated_rows: root.estimated_output_rows(),
463 plan_depth: root.depth(),
464 };
465
466 Ok(JoinPlan { root, stats })
467 }
468
469 fn greedy_order(
471 &self,
472 relations: &[Relation],
473 conditions: &[JoinCondition],
474 ) -> Result<JoinPlanNode, JoinOrderError> {
475 if relations.len() == 1 {
476 let r = &relations[0];
477 return Ok(JoinPlanNode::Scan {
478 relation: r.name.clone(),
479 estimated_cost: r.estimated_rows,
480 });
481 }
482
483 let mut nodes: Vec<JoinPlanNode> = {
485 let mut v: Vec<_> = relations.iter().collect();
486 v.sort_by_key(|r| r.estimated_rows);
487 v.into_iter()
488 .map(|r| JoinPlanNode::Scan {
489 relation: r.name.clone(),
490 estimated_cost: r.estimated_rows,
491 })
492 .collect()
493 };
494
495 while nodes.len() > 1 {
496 let mut best_i = 0;
497 let mut best_j = 1;
498 let mut best_cost = u64::MAX;
499 let mut best_rows = u64::MAX;
500
501 for i in 0..nodes.len() {
502 for j in (i + 1)..nodes.len() {
503 let left_rels: HashSet<String> =
504 nodes[i].relations_involved().into_iter().collect();
505 let right_rels: HashSet<String> =
506 nodes[j].relations_involved().into_iter().collect();
507 let conds = Self::find_conditions(&left_rels, &right_rels, conditions);
508 let (cost, rows) = self.estimate_join_cost(&nodes[i], &nodes[j], &conds);
509 if cost < best_cost || (cost == best_cost && rows < best_rows) {
510 best_cost = cost;
511 best_rows = rows;
512 best_i = i;
513 best_j = j;
514 }
515 }
516 }
517
518 let right_node = nodes.remove(best_j);
520 let left_node = nodes.remove(best_i);
521
522 let left_rels: HashSet<String> = left_node.relations_involved().into_iter().collect();
523 let right_rels: HashSet<String> = right_node.relations_involved().into_iter().collect();
524 let conds = Self::find_conditions(&left_rels, &right_rels, conditions);
525 let (cost, rows) = self.estimate_join_cost(&left_node, &right_node, &conds);
526
527 let joined = self.make_join_node(left_node, right_node, conds, cost, rows);
528 nodes.push(joined);
529 }
530
531 Ok(nodes
533 .into_iter()
534 .next()
535 .unwrap_or_else(|| JoinPlanNode::Scan {
536 relation: String::new(),
537 estimated_cost: 0,
538 }))
539 }
540
541 fn dp_order(
545 &self,
546 relations: &[Relation],
547 conditions: &[JoinCondition],
548 ) -> Result<JoinPlanNode, JoinOrderError> {
549 let n = relations.len();
550 if n == 1 {
551 let r = &relations[0];
552 return Ok(JoinPlanNode::Scan {
553 relation: r.name.clone(),
554 estimated_cost: r.estimated_rows,
555 });
556 }
557
558 let idx_to_name: Vec<&str> = relations.iter().map(|r| r.name.as_str()).collect();
560
561 let mut dp: HashMap<BTreeSet<usize>, (JoinPlanNode, u64)> = HashMap::new();
563
564 for (i, r) in relations.iter().enumerate() {
566 let mut set = BTreeSet::new();
567 set.insert(i);
568 let node = JoinPlanNode::Scan {
569 relation: r.name.clone(),
570 estimated_cost: r.estimated_rows,
571 };
572 dp.insert(set, (node, r.estimated_rows));
573 }
574
575 for size in 2..=n {
577 let subsets = subsets_of_size(n, size);
578 for subset in &subsets {
579 let mut best: Option<(JoinPlanNode, u64)> = None;
580
581 let elems: Vec<usize> = subset.iter().copied().collect();
584 let m = elems.len();
585
586 for s1_size in 1..m {
587 let s1_subsets = subsets_of_size(m, s1_size);
588 for s1_indices in &s1_subsets {
589 let s1: BTreeSet<usize> =
590 s1_indices.iter().map(|&idx| elems[idx]).collect();
591 let s2: BTreeSet<usize> = subset.difference(&s1).copied().collect();
592
593 if s2.is_empty() {
594 continue;
595 }
596
597 let (left_plan, _left_cost) = match dp.get(&s1) {
598 Some(v) => v,
599 None => continue,
600 };
601 let (right_plan, _right_cost) = match dp.get(&s2) {
602 Some(v) => v,
603 None => continue,
604 };
605
606 let left_names: HashSet<String> =
608 s1.iter().map(|&i| idx_to_name[i].to_string()).collect();
609 let right_names: HashSet<String> =
610 s2.iter().map(|&i| idx_to_name[i].to_string()).collect();
611 let conds = Self::find_conditions(&left_names, &right_names, conditions);
612
613 let (cost, rows) = self.estimate_join_cost(left_plan, right_plan, &conds);
614
615 if best.as_ref().is_none_or(|(_, bc)| cost < *bc) {
616 let node = self.make_join_node(
617 left_plan.clone(),
618 right_plan.clone(),
619 conds,
620 cost,
621 rows,
622 );
623 best = Some((node, cost));
624 }
625 }
626 }
627
628 if let Some(entry) = best {
629 dp.insert(subset.clone(), entry);
630 }
631 }
632 }
633
634 let full: BTreeSet<usize> = (0..n).collect();
636 dp.remove(&full).map(|(node, _)| node).ok_or_else(|| {
637 JoinOrderError::DisconnectedGraph(
638 "could not find a plan covering all relations".to_string(),
639 )
640 })
641 }
642
643 fn estimate_join_cost(
645 &self,
646 left: &JoinPlanNode,
647 right: &JoinPlanNode,
648 conditions: &[JoinCondition],
649 ) -> (u64, u64) {
650 let left_rows = left.estimated_output_rows().max(1);
651 let right_rows = right.estimated_output_rows().max(1);
652
653 let selectivity = estimate_selectivity(
654 left_rows,
655 right_rows,
656 conditions.len(),
657 self.config.default_selectivity,
658 );
659
660 let output_rows =
661 ((left_rows as f64 * right_rows as f64 * selectivity).ceil() as u64).max(1);
662
663 let use_hash = right_rows > self.config.hash_join_threshold;
664 let join_cost = if use_hash {
665 left_rows + right_rows + output_rows
667 } else {
668 (left_rows.saturating_mul(right_rows)).max(left_rows + right_rows)
670 };
671
672 let total_cost = left
673 .cost()
674 .saturating_add(right.cost())
675 .saturating_add(join_cost);
676 (total_cost, output_rows)
677 }
678
679 fn find_conditions(
681 left_rels: &HashSet<String>,
682 right_rels: &HashSet<String>,
683 all_conditions: &[JoinCondition],
684 ) -> Vec<JoinCondition> {
685 all_conditions
686 .iter()
687 .filter(|c| {
688 (left_rels.contains(&c.left_relation) && right_rels.contains(&c.right_relation))
689 || (left_rels.contains(&c.right_relation)
690 && right_rels.contains(&c.left_relation))
691 })
692 .cloned()
693 .collect()
694 }
695
696 fn make_join_node(
698 &self,
699 left: JoinPlanNode,
700 right: JoinPlanNode,
701 conditions: Vec<JoinCondition>,
702 estimated_cost: u64,
703 estimated_rows: u64,
704 ) -> JoinPlanNode {
705 let right_rows = right.estimated_output_rows();
706 let use_hash = right_rows > self.config.hash_join_threshold;
707
708 let (left, right) = if self.config.prefer_small_left && use_hash {
709 if left.estimated_output_rows() > right.estimated_output_rows() {
710 (right, left)
711 } else {
712 (left, right)
713 }
714 } else {
715 (left, right)
716 };
717
718 if use_hash {
719 JoinPlanNode::HashJoin {
720 left: Box::new(left),
721 right: Box::new(right),
722 conditions,
723 estimated_cost,
724 estimated_rows,
725 }
726 } else {
727 JoinPlanNode::NestedLoopJoin {
728 left: Box::new(left),
729 right: Box::new(right),
730 conditions,
731 estimated_cost,
732 estimated_rows,
733 }
734 }
735 }
736}
737
738impl Default for JoinOrderOptimizer {
739 fn default() -> Self {
740 Self::with_default()
741 }
742}
743
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the condition linking `left.left_column` with `right.right_column`.
    fn cond(left: &str, left_column: usize, right: &str, right_column: usize) -> JoinCondition {
        JoinCondition {
            left_relation: left.to_string(),
            left_column,
            right_relation: right.to_string(),
            right_column,
        }
    }

    /// Three arity-2 relations joined in a chain:
    /// `names[0].0 ~ names[1].0` and `names[1].1 ~ names[2].0`.
    fn chain3(names: [&str; 3], rows: [u64; 3]) -> (Vec<Relation>, Vec<JoinCondition>) {
        let rels: Vec<Relation> = names
            .iter()
            .zip(rows)
            .map(|(&name, estimated)| Relation::new(name, 2, estimated))
            .collect();
        let conds = vec![cond(names[0], 0, names[1], 0), cond(names[1], 1, names[2], 0)];
        (rels, conds)
    }

    /// A leaf scan node.
    fn scan(relation: &str, estimated_cost: u64) -> JoinPlanNode {
        JoinPlanNode::Scan {
            relation: relation.to_string(),
            estimated_cost,
        }
    }

    /// Hand-built hash join of scans `a` (cost 5) and `b` (cost 10).
    fn hash_join_ab() -> JoinPlanNode {
        JoinPlanNode::HashJoin {
            left: Box::new(scan("a", 5)),
            right: Box::new(scan("b", 10)),
            conditions: vec![],
            estimated_cost: 20,
            estimated_rows: 8,
        }
    }

    /// Optimized plan for `a` (100 rows) joined with `b` (200 rows) on `a.0 ~ b.0`.
    fn ab_plan() -> JoinPlan {
        let rels = [Relation::new("a", 2, 100), Relation::new("b", 2, 200)];
        JoinOrderOptimizer::with_default()
            .optimize(&rels, &[cond("a", 0, "b", 0)])
            .expect("ok")
    }

    #[test]
    fn test_relation_new() {
        let r = Relation::new("users", 3, 1000);
        assert_eq!(
            (r.name.as_str(), r.arity, r.estimated_rows),
            ("users", 3, 1000)
        );
        assert!(r.bound_columns.is_empty());
    }

    #[test]
    fn test_relation_with_binding() {
        let r = Relation::new("users", 3, 1000)
            .with_binding(0)
            .with_binding(2);
        let bound: Vec<usize> = r.bound_columns.iter().copied().collect();
        assert_eq!(bound, vec![0, 2]);
    }

    #[test]
    fn test_relation_selectivity() {
        let half = Relation::new("users", 4, 1000)
            .with_binding(0)
            .with_binding(1)
            .selectivity();
        assert!((half - 0.5).abs() < 1e-10);

        let none = Relation::new("empty", 0, 0).selectivity();
        assert!((none - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_join_config_default() {
        let cfg = JoinOptimizerConfig::default();
        assert_eq!(cfg.max_relations, 10);
        assert_eq!(cfg.hash_join_threshold, 100);
        assert!((cfg.default_selectivity - 0.1).abs() < 1e-10);
        assert!(cfg.prefer_small_left);
    }

    #[test]
    fn test_greedy_single_relation() {
        let plan = JoinOrderOptimizer::with_default()
            .optimize(&[Relation::new("users", 3, 100)], &[])
            .expect("should succeed");
        assert!(matches!(plan.root, JoinPlanNode::Scan { .. }));
        assert_eq!(plan.stats.relations_scanned, 1);
        assert_eq!(plan.stats.joins_performed, 0);
    }

    #[test]
    fn test_greedy_two_relations() {
        let rels = [
            Relation::new("users", 2, 500),
            Relation::new("orders", 3, 2000),
        ];
        let conds = [cond("users", 0, "orders", 1)];
        let plan = JoinOrderOptimizer::with_default()
            .optimize(&rels, &conds)
            .expect("should succeed");
        assert_eq!(plan.stats.relations_scanned, 2);
        assert_eq!(plan.stats.joins_performed, 1);
        assert!(plan.root.cost() > 0);
    }

    #[test]
    fn test_greedy_three_relations() {
        let (rels, conds) = chain3(["a", "b", "c"], [100, 200, 300]);
        let plan = JoinOrderOptimizer::with_default()
            .optimize(&rels, &conds)
            .expect("should succeed");
        assert_eq!(plan.stats.relations_scanned, 3);
        assert_eq!(plan.stats.joins_performed, 2);
        assert!(plan.root.depth() >= 2);
    }

    #[test]
    fn test_dp_two_relations() {
        let rels = [Relation::new("x", 2, 50), Relation::new("y", 2, 80)];
        let conds = [cond("x", 0, "y", 0)];
        let plan = JoinOrderOptimizer::with_default()
            .optimize(&rels, &conds)
            .expect("should succeed");
        assert_eq!(plan.stats.relations_scanned, 2);
        assert_eq!(plan.stats.joins_performed, 1);
    }

    #[test]
    fn test_dp_three_relations() {
        let (rels, conds) = chain3(["r1", "r2", "r3"], [10, 20, 30]);
        let plan = JoinOrderOptimizer::with_default()
            .optimize(&rels, &conds)
            .expect("should succeed");
        assert_eq!(plan.stats.relations_scanned, 3);
        assert_eq!(plan.stats.joins_performed, 2);
        assert!(plan.root.depth() >= 2);
    }

    #[test]
    fn test_optimize_uses_greedy_when_too_many() {
        let opt = JoinOrderOptimizer::new(JoinOptimizerConfig {
            max_relations: 2,
            ..Default::default()
        });
        let (rels, conds) = chain3(["a", "b", "c"], [10, 20, 30]);
        let plan = opt.optimize(&rels, &conds).expect("greedy fallback");
        assert_eq!(plan.stats.relations_scanned, 3);
    }

    #[test]
    fn test_optimize_no_relations_error() {
        let result = JoinOrderOptimizer::with_default().optimize(&[], &[]);
        assert!(matches!(result, Err(JoinOrderError::NoRelations)));
    }

    #[test]
    fn test_join_plan_node_cost() {
        let node = scan("t", 42);
        assert_eq!(node.cost(), 42);
        assert!(node.cost() > 0);
    }

    #[test]
    fn test_join_plan_node_depth() {
        assert_eq!(scan("t", 10).depth(), 1);
        assert_eq!(hash_join_ab().depth(), 2);
    }

    #[test]
    fn test_join_plan_node_relations() {
        let mut rels = hash_join_ab().relations_involved();
        rels.sort();
        assert_eq!(rels, ["a", "b"]);
    }

    #[test]
    fn test_join_plan_format_tree() {
        assert!(!ab_plan().format_tree().is_empty());
    }

    #[test]
    fn test_join_plan_format_dot() {
        assert!(ab_plan().format_dot().contains("digraph"));
    }

    #[test]
    fn test_estimate_selectivity() {
        let sel = estimate_selectivity(1000, 2000, 1, 0.1);
        assert!(sel > 0.0 && sel <= 1.0);

        // No conditions means a cross product.
        let cross = estimate_selectivity(100, 100, 0, 0.1);
        assert!((cross - 1.0).abs() < 1e-10);

        // Each extra condition makes the join more selective.
        let one = estimate_selectivity(100, 200, 1, 0.1);
        let two = estimate_selectivity(100, 200, 2, 0.1);
        assert!(two < one);
    }

    #[test]
    fn test_find_conditions() {
        let conds = [cond("a", 0, "b", 0), cond("b", 1, "c", 0)];
        let set = |name: &str| -> HashSet<String> { HashSet::from([name.to_string()]) };

        let found = JoinOrderOptimizer::find_conditions(&set("a"), &set("b"), &conds);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].left_relation, "a");

        assert!(JoinOrderOptimizer::find_conditions(&set("a"), &set("c"), &conds).is_empty());
    }

    #[test]
    fn test_join_stats() {
        let (rels, conds) = chain3(["a", "b", "c"], [100, 200, 300]);
        let stats = JoinOrderOptimizer::with_default()
            .optimize(&rels, &conds)
            .expect("ok")
            .stats;
        assert_eq!(stats.relations_scanned, 3);
        assert_eq!(stats.joins_performed, 2);
        assert!(stats.total_estimated_cost > 0);
        assert!(stats.total_estimated_rows > 0);
        assert!(stats.plan_depth >= 2);
    }

    #[test]
    fn test_join_order_error_display() {
        assert!(!JoinOrderError::NoRelations.to_string().is_empty());
        assert!(JoinOrderError::DisconnectedGraph("parts missing".to_string())
            .to_string()
            .contains("disconnected"));
        assert!(JoinOrderError::TooManyRelations { count: 20, max: 10 }
            .to_string()
            .contains("20"));
        assert!(JoinOrderError::InvalidCondition("bad ref".to_string())
            .to_string()
            .contains("invalid"));
    }
}