Skip to main content

tensorlogic_infer/
join_order.rs

1//! Join ordering optimization for logic queries.
2//!
3//! Given a set of predicates/atoms to be joined, finds the cheapest join order
4//! based on selectivity estimates. This is a classic Datalog/DB query optimization
5//! problem relevant to TensorLogic's inference engine.
6//!
7//! ## Strategies
8//!
9//! - **Dynamic Programming (System R style)**: For small numbers of relations
10//!   (≤ `max_relations`), enumerates all subset partitions to find the optimal plan.
11//! - **Greedy**: For larger queries, repeatedly picks the cheapest next join. Each round examines every remaining pair (O(n²)), so a full run evaluates O(n³) candidate joins.
12
13use std::collections::{BTreeSet, HashMap, HashSet};
14use std::fmt;
15
16// ---------------------------------------------------------------------------
17// Relation
18// ---------------------------------------------------------------------------
19
/// A relation (predicate) with arity and estimated cardinality.
#[derive(Debug, Clone)]
pub struct Relation {
    /// Name of the relation / predicate.
    pub name: String,
    /// Number of columns (arity).
    pub arity: usize,
    /// Estimated number of rows (tuples).
    pub estimated_rows: u64,
    /// Which columns have known bindings (constants / previously bound variables).
    ///
    /// Indices are stored exactly as supplied; nothing checks them against `arity`.
    pub bound_columns: BTreeSet<usize>,
}

impl Relation {
    /// Create a new relation with no bound columns.
    pub fn new(name: impl Into<String>, arity: usize, estimated_rows: u64) -> Self {
        Self {
            name: name.into(),
            arity,
            estimated_rows,
            bound_columns: BTreeSet::new(),
        }
    }

    /// Mark a column as bound and return self (builder pattern).
    ///
    /// Binding the same column twice is a no-op because the columns are kept in a set.
    /// The index is not validated against `arity`.
    #[must_use]
    pub fn with_binding(mut self, col: usize) -> Self {
        self.bound_columns.insert(col);
        self
    }

    /// Fraction of columns that are bound, always in `[0.0, 1.0]`.
    ///
    /// Returns 0.0 when arity is 0. Bound indices that do not name a real column
    /// (`>= arity`) are ignored so that the result stays a true fraction.
    pub fn selectivity(&self) -> f64 {
        if self.arity == 0 {
            return 0.0;
        }
        // Only indices in `0..arity` refer to actual columns of this relation.
        let bound_in_range = self.bound_columns.range(..self.arity).count();
        bound_in_range as f64 / self.arity as f64
    }
}
58
59// ---------------------------------------------------------------------------
60// JoinCondition
61// ---------------------------------------------------------------------------
62
/// A join condition between two relations (equi-join on one column each).
///
/// Derives `PartialEq`/`Eq`/`Hash` so conditions can be compared in tests and
/// deduplicated in sets or maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JoinCondition {
    /// Name of the relation on the left-hand side of the equality.
    pub left_relation: String,
    /// Column index within `left_relation`.
    pub left_column: usize,
    /// Name of the relation on the right-hand side of the equality.
    pub right_relation: String,
    /// Column index within `right_relation`.
    pub right_column: usize,
}
71
72// ---------------------------------------------------------------------------
73// JoinPlanNode
74// ---------------------------------------------------------------------------
75
76/// A node in a join plan tree.
77#[derive(Debug, Clone)]
78pub enum JoinPlanNode {
79    /// Leaf: scan a single relation.
80    Scan {
81        relation: String,
82        estimated_cost: u64,
83    },
84    /// Hash join (good for larger inner relations).
85    HashJoin {
86        left: Box<JoinPlanNode>,
87        right: Box<JoinPlanNode>,
88        conditions: Vec<JoinCondition>,
89        estimated_cost: u64,
90        estimated_rows: u64,
91    },
92    /// Nested-loop join (fallback / small inner).
93    NestedLoopJoin {
94        left: Box<JoinPlanNode>,
95        right: Box<JoinPlanNode>,
96        conditions: Vec<JoinCondition>,
97        estimated_cost: u64,
98        estimated_rows: u64,
99    },
100}
101
102impl JoinPlanNode {
103    /// Total estimated cost of this sub-plan.
104    pub fn cost(&self) -> u64 {
105        match self {
106            Self::Scan { estimated_cost, .. } => *estimated_cost,
107            Self::HashJoin { estimated_cost, .. } => *estimated_cost,
108            Self::NestedLoopJoin { estimated_cost, .. } => *estimated_cost,
109        }
110    }
111
112    /// Estimated number of output rows.
113    pub fn estimated_output_rows(&self) -> u64 {
114        match self {
115            Self::Scan { estimated_cost, .. } => *estimated_cost, // rows == cost for scans
116            Self::HashJoin { estimated_rows, .. } => *estimated_rows,
117            Self::NestedLoopJoin { estimated_rows, .. } => *estimated_rows,
118        }
119    }
120
121    /// Depth of the plan tree (leaf = 1).
122    pub fn depth(&self) -> usize {
123        match self {
124            Self::Scan { .. } => 1,
125            Self::HashJoin { left, right, .. } | Self::NestedLoopJoin { left, right, .. } => {
126                1 + left.depth().max(right.depth())
127            }
128        }
129    }
130
131    /// Collect all relation names involved in this sub-plan.
132    pub fn relations_involved(&self) -> Vec<String> {
133        let mut out = Vec::new();
134        self.collect_relations(&mut out);
135        out
136    }
137
138    fn collect_relations(&self, out: &mut Vec<String>) {
139        match self {
140            Self::Scan { relation, .. } => out.push(relation.clone()),
141            Self::HashJoin { left, right, .. } | Self::NestedLoopJoin { left, right, .. } => {
142                left.collect_relations(out);
143                right.collect_relations(out);
144            }
145        }
146    }
147
148    /// Recursive helper for `format_tree`.
149    fn format_tree_inner(&self, indent: usize, buf: &mut String) {
150        let pad = " ".repeat(indent);
151        match self {
152            Self::Scan {
153                relation,
154                estimated_cost,
155            } => {
156                buf.push_str(&format!("{pad}Scan({relation}, cost={estimated_cost})\n"));
157            }
158            Self::HashJoin {
159                left,
160                right,
161                estimated_cost,
162                estimated_rows,
163                ..
164            } => {
165                buf.push_str(&format!(
166                    "{pad}HashJoin(cost={estimated_cost}, rows={estimated_rows})\n"
167                ));
168                left.format_tree_inner(indent + 2, buf);
169                right.format_tree_inner(indent + 2, buf);
170            }
171            Self::NestedLoopJoin {
172                left,
173                right,
174                estimated_cost,
175                estimated_rows,
176                ..
177            } => {
178                buf.push_str(&format!(
179                    "{pad}NestedLoopJoin(cost={estimated_cost}, rows={estimated_rows})\n"
180                ));
181                left.format_tree_inner(indent + 2, buf);
182                right.format_tree_inner(indent + 2, buf);
183            }
184        }
185    }
186
187    /// Recursive DOT helper. Returns the node id assigned.
188    fn format_dot_inner(&self, counter: &mut usize, buf: &mut String) -> usize {
189        let id = *counter;
190        *counter += 1;
191        match self {
192            Self::Scan {
193                relation,
194                estimated_cost,
195            } => {
196                buf.push_str(&format!(
197                    "  n{id} [label=\"Scan({relation})\\ncost={estimated_cost}\"];\n"
198                ));
199            }
200            Self::HashJoin {
201                left,
202                right,
203                estimated_cost,
204                estimated_rows,
205                ..
206            } => {
207                buf.push_str(&format!(
208                    "  n{id} [label=\"HashJoin\\ncost={estimated_cost} rows={estimated_rows}\"];\n"
209                ));
210                let lid = left.format_dot_inner(counter, buf);
211                let rid = right.format_dot_inner(counter, buf);
212                buf.push_str(&format!("  n{id} -> n{lid};\n"));
213                buf.push_str(&format!("  n{id} -> n{rid};\n"));
214            }
215            Self::NestedLoopJoin {
216                left,
217                right,
218                estimated_cost,
219                estimated_rows,
220                ..
221            } => {
222                buf.push_str(&format!(
223                    "  n{id} [label=\"NLJoin\\ncost={estimated_cost} rows={estimated_rows}\"];\n"
224                ));
225                let lid = left.format_dot_inner(counter, buf);
226                let rid = right.format_dot_inner(counter, buf);
227                buf.push_str(&format!("  n{id} -> n{lid};\n"));
228                buf.push_str(&format!("  n{id} -> n{rid};\n"));
229            }
230        }
231        id
232    }
233}
234
235// ---------------------------------------------------------------------------
236// JoinStats
237// ---------------------------------------------------------------------------
238
/// Aggregate statistics for a join plan.
///
/// Derives `PartialEq`/`Eq` so two plans' statistics can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JoinStats {
    /// Number of leaf scans (relation occurrences) in the plan.
    pub relations_scanned: usize,
    /// Number of join nodes in the plan.
    pub joins_performed: usize,
    /// Estimated cost of the whole plan.
    pub total_estimated_cost: u64,
    /// Estimated number of rows produced by the plan's root.
    pub total_estimated_rows: u64,
    /// Depth of the plan tree.
    pub plan_depth: usize,
}
248
249// ---------------------------------------------------------------------------
250// JoinPlan
251// ---------------------------------------------------------------------------
252
253/// A complete join plan with root node and statistics.
254#[derive(Debug, Clone)]
255pub struct JoinPlan {
256    pub root: JoinPlanNode,
257    pub stats: JoinStats,
258}
259
260impl JoinPlan {
261    /// Indented tree representation.
262    pub fn format_tree(&self) -> String {
263        let mut buf = String::new();
264        self.root.format_tree_inner(0, &mut buf);
265        buf
266    }
267
268    /// DOT graph representation.
269    pub fn format_dot(&self) -> String {
270        let mut buf = String::from("digraph JoinPlan {\n");
271        let mut counter = 0usize;
272        self.root.format_dot_inner(&mut counter, &mut buf);
273        buf.push_str("}\n");
274        buf
275    }
276
277    /// Total estimated cost of the plan.
278    pub fn total_cost(&self) -> u64 {
279        self.root.cost()
280    }
281}
282
283// ---------------------------------------------------------------------------
284// JoinOptimizerConfig
285// ---------------------------------------------------------------------------
286
/// Tunable settings for the join order optimizer.
#[derive(Debug, Clone)]
pub struct JoinOptimizerConfig {
    /// Largest relation count still handled by the exhaustive search; anything
    /// above it falls back to greedy.
    pub max_relations: usize,
    /// Use hash join when the inner relation exceeds this many rows.
    pub hash_join_threshold: u64,
    /// Selectivity assumed when nothing better is known.
    pub default_selectivity: f64,
    /// When true, place the smaller relation on the left (build side) in hash joins.
    pub prefer_small_left: bool,
}

impl Default for JoinOptimizerConfig {
    /// Defaults: exhaustive search up to 10 relations, hash join above 100 inner
    /// rows, selectivity 0.1, smaller hash-join input on the left.
    fn default() -> Self {
        let max_relations = 10;
        let hash_join_threshold = 100;
        let default_selectivity = 0.1;
        let prefer_small_left = true;
        Self {
            max_relations,
            hash_join_threshold,
            default_selectivity,
            prefer_small_left,
        }
    }
}
310
311// ---------------------------------------------------------------------------
312// JoinOrderError
313// ---------------------------------------------------------------------------
314
/// Failure modes reported by the join optimizer.
#[derive(Debug, Clone)]
pub enum JoinOrderError {
    /// The caller supplied an empty relation list.
    NoRelations,
    /// The join graph is disconnected; the payload carries details.
    DisconnectedGraph(String),
    /// More relations (`count`) than the requested strategy accepts (`max`).
    TooManyRelations { count: usize, max: usize },
    /// A join condition names a relation that is not part of the input set.
    InvalidCondition(String),
}

impl fmt::Display for JoinOrderError {
    /// Writes a one-line, lower-case description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NoRelations => f.write_str("no relations provided for join ordering"),
            Self::DisconnectedGraph(detail) => {
                f.write_fmt(format_args!("disconnected join graph: {detail}"))
            }
            Self::TooManyRelations { count, max } => f.write_fmt(format_args!(
                "too many relations ({count}) for exhaustive search (max {max})"
            )),
            Self::InvalidCondition(detail) => {
                f.write_fmt(format_args!("invalid join condition: {detail}"))
            }
        }
    }
}

impl std::error::Error for JoinOrderError {}
345
346// ---------------------------------------------------------------------------
347// Helpers
348// ---------------------------------------------------------------------------
349
/// Enumerate every `k`-element subset of `{0..n-1}`.
///
/// Subsets come out in lexicographic order of their sorted elements. `k > n`
/// yields no subsets; `k == 0` yields exactly one (the empty set).
fn subsets_of_size(n: usize, k: usize) -> Vec<BTreeSet<usize>> {
    let mut combos = Vec::new();
    if k > n {
        return combos;
    }

    // `picks` is the current combination, kept strictly increasing.
    let mut picks: Vec<usize> = (0..k).collect();
    loop {
        combos.push(picks.iter().copied().collect());

        // Rightmost slot that has not yet reached its largest allowed value
        // (slot `pos` tops out at `pos + n - k`). None left means we are done.
        let Some(pivot) = (0..k).rev().find(|&pos| picks[pos] != pos + n - k) else {
            return combos;
        };

        // Bump that slot and reset everything to its right to the smallest
        // increasing run that follows it.
        picks[pivot] += 1;
        for pos in (pivot + 1)..k {
            picks[pos] = picks[pos - 1] + 1;
        }
    }
}
379
/// Selectivity estimation helper.
///
/// Returns a value in `(0, 1]` estimating the fraction of the cross-product
/// that survives the join.
///
/// * `left_rows` / `right_rows` - estimated input cardinalities.
/// * `num_conditions` - number of join conditions between the inputs; zero means
///   a cross product and always yields `1.0`.
/// * `default_selectivity` - lower bound applied to the per-condition factor.
///
/// Each condition contributes a factor of
/// `max(1 / max(left_rows, right_rows, 1), default_selectivity)`, and the factors
/// are multiplied together. The product is clamped to `[f64::MIN_POSITIVE, 1.0]`.
pub fn estimate_selectivity(
    left_rows: u64,
    right_rows: u64,
    num_conditions: usize,
    default_selectivity: f64,
) -> f64 {
    if num_conditions == 0 {
        return 1.0; // cross product
    }
    let max_side = left_rows.max(right_rows).max(1) as f64;
    // Each condition filters by roughly 1/max(|L|,|R|), but never by less than
    // default_selectivity (the `max` makes it a floor, not a ceiling).
    let per_cond = (1.0 / max_side).max(default_selectivity);
    // Saturate instead of `as i32`: a plain cast wraps for counts above
    // i32::MAX and can even produce a negative exponent.
    let exponent = num_conditions.min(i32::MAX as usize) as i32;
    per_cond.powi(exponent).clamp(f64::MIN_POSITIVE, 1.0)
}
399
400// ---------------------------------------------------------------------------
401// JoinOrderOptimizer
402// ---------------------------------------------------------------------------
403
404/// The join order optimizer.
405pub struct JoinOrderOptimizer {
406    config: JoinOptimizerConfig,
407}
408
409impl JoinOrderOptimizer {
410    /// Create with explicit configuration.
411    pub fn new(config: JoinOptimizerConfig) -> Self {
412        Self { config }
413    }
414
415    /// Create with default configuration.
416    pub fn with_default() -> Self {
417        Self::new(JoinOptimizerConfig::default())
418    }
419
420    /// Find optimal join order for a set of relations with join conditions.
421    ///
422    /// For `<= max_relations`: dynamic programming (System R style).
423    /// For `> max_relations`: greedy (smallest estimated output first).
424    pub fn optimize(
425        &self,
426        relations: &[Relation],
427        conditions: &[JoinCondition],
428    ) -> Result<JoinPlan, JoinOrderError> {
429        if relations.is_empty() {
430            return Err(JoinOrderError::NoRelations);
431        }
432
433        // Validate conditions reference known relations
434        let known: HashSet<&str> = relations.iter().map(|r| r.name.as_str()).collect();
435        for c in conditions {
436            if !known.contains(c.left_relation.as_str()) {
437                return Err(JoinOrderError::InvalidCondition(format!(
438                    "unknown relation '{}'",
439                    c.left_relation
440                )));
441            }
442            if !known.contains(c.right_relation.as_str()) {
443                return Err(JoinOrderError::InvalidCondition(format!(
444                    "unknown relation '{}'",
445                    c.right_relation
446                )));
447            }
448        }
449
450        let root = if relations.len() > self.config.max_relations {
451            self.greedy_order(relations, conditions)?
452        } else {
453            self.dp_order(relations, conditions)?
454        };
455
456        let rels = root.relations_involved();
457        let joins = if rels.len() > 1 { rels.len() - 1 } else { 0 };
458        let stats = JoinStats {
459            relations_scanned: rels.len(),
460            joins_performed: joins,
461            total_estimated_cost: root.cost(),
462            total_estimated_rows: root.estimated_output_rows(),
463            plan_depth: root.depth(),
464        };
465
466        Ok(JoinPlan { root, stats })
467    }
468
469    /// Greedy join ordering: always pick the cheapest next join.
470    fn greedy_order(
471        &self,
472        relations: &[Relation],
473        conditions: &[JoinCondition],
474    ) -> Result<JoinPlanNode, JoinOrderError> {
475        if relations.len() == 1 {
476            let r = &relations[0];
477            return Ok(JoinPlanNode::Scan {
478                relation: r.name.clone(),
479                estimated_cost: r.estimated_rows,
480            });
481        }
482
483        // Build initial nodes sorted by estimated_rows ascending.
484        let mut nodes: Vec<JoinPlanNode> = {
485            let mut v: Vec<_> = relations.iter().collect();
486            v.sort_by_key(|r| r.estimated_rows);
487            v.into_iter()
488                .map(|r| JoinPlanNode::Scan {
489                    relation: r.name.clone(),
490                    estimated_cost: r.estimated_rows,
491                })
492                .collect()
493        };
494
495        while nodes.len() > 1 {
496            let mut best_i = 0;
497            let mut best_j = 1;
498            let mut best_cost = u64::MAX;
499            let mut best_rows = u64::MAX;
500
501            for i in 0..nodes.len() {
502                for j in (i + 1)..nodes.len() {
503                    let left_rels: HashSet<String> =
504                        nodes[i].relations_involved().into_iter().collect();
505                    let right_rels: HashSet<String> =
506                        nodes[j].relations_involved().into_iter().collect();
507                    let conds = Self::find_conditions(&left_rels, &right_rels, conditions);
508                    let (cost, rows) = self.estimate_join_cost(&nodes[i], &nodes[j], &conds);
509                    if cost < best_cost || (cost == best_cost && rows < best_rows) {
510                        best_cost = cost;
511                        best_rows = rows;
512                        best_i = i;
513                        best_j = j;
514                    }
515                }
516            }
517
518            // Remove j first (larger index) then i.
519            let right_node = nodes.remove(best_j);
520            let left_node = nodes.remove(best_i);
521
522            let left_rels: HashSet<String> = left_node.relations_involved().into_iter().collect();
523            let right_rels: HashSet<String> = right_node.relations_involved().into_iter().collect();
524            let conds = Self::find_conditions(&left_rels, &right_rels, conditions);
525            let (cost, rows) = self.estimate_join_cost(&left_node, &right_node, &conds);
526
527            let joined = self.make_join_node(left_node, right_node, conds, cost, rows);
528            nodes.push(joined);
529        }
530
531        // Safety: we checked len >= 1 above and the loop leaves exactly 1 element.
532        Ok(nodes
533            .into_iter()
534            .next()
535            .unwrap_or_else(|| JoinPlanNode::Scan {
536                relation: String::new(),
537                estimated_cost: 0,
538            }))
539    }
540
541    /// Dynamic programming (System R style) for small number of relations.
542    ///
543    /// `dp[S]` = best plan for subset S of relations.
544    fn dp_order(
545        &self,
546        relations: &[Relation],
547        conditions: &[JoinCondition],
548    ) -> Result<JoinPlanNode, JoinOrderError> {
549        let n = relations.len();
550        if n == 1 {
551            let r = &relations[0];
552            return Ok(JoinPlanNode::Scan {
553                relation: r.name.clone(),
554                estimated_cost: r.estimated_rows,
555            });
556        }
557
558        // Map index → relation name.
559        let idx_to_name: Vec<&str> = relations.iter().map(|r| r.name.as_str()).collect();
560
561        // dp table: BTreeSet<usize> → (best plan, cost)
562        let mut dp: HashMap<BTreeSet<usize>, (JoinPlanNode, u64)> = HashMap::new();
563
564        // Base case: single relations.
565        for (i, r) in relations.iter().enumerate() {
566            let mut set = BTreeSet::new();
567            set.insert(i);
568            let node = JoinPlanNode::Scan {
569                relation: r.name.clone(),
570                estimated_cost: r.estimated_rows,
571            };
572            dp.insert(set, (node, r.estimated_rows));
573        }
574
575        // Enumerate subset sizes 2..=n
576        for size in 2..=n {
577            let subsets = subsets_of_size(n, size);
578            for subset in &subsets {
579                let mut best: Option<(JoinPlanNode, u64)> = None;
580
581                // Try all non-empty proper subsets s1 of subset.
582                // We enumerate s1 as subsets of `subset` with size 1..size-1.
583                let elems: Vec<usize> = subset.iter().copied().collect();
584                let m = elems.len();
585
586                for s1_size in 1..m {
587                    let s1_subsets = subsets_of_size(m, s1_size);
588                    for s1_indices in &s1_subsets {
589                        let s1: BTreeSet<usize> =
590                            s1_indices.iter().map(|&idx| elems[idx]).collect();
591                        let s2: BTreeSet<usize> = subset.difference(&s1).copied().collect();
592
593                        if s2.is_empty() {
594                            continue;
595                        }
596
597                        let (left_plan, _left_cost) = match dp.get(&s1) {
598                            Some(v) => v,
599                            None => continue,
600                        };
601                        let (right_plan, _right_cost) = match dp.get(&s2) {
602                            Some(v) => v,
603                            None => continue,
604                        };
605
606                        // Find join conditions between s1 and s2
607                        let left_names: HashSet<String> =
608                            s1.iter().map(|&i| idx_to_name[i].to_string()).collect();
609                        let right_names: HashSet<String> =
610                            s2.iter().map(|&i| idx_to_name[i].to_string()).collect();
611                        let conds = Self::find_conditions(&left_names, &right_names, conditions);
612
613                        let (cost, rows) = self.estimate_join_cost(left_plan, right_plan, &conds);
614
615                        if best.as_ref().is_none_or(|(_, bc)| cost < *bc) {
616                            let node = self.make_join_node(
617                                left_plan.clone(),
618                                right_plan.clone(),
619                                conds,
620                                cost,
621                                rows,
622                            );
623                            best = Some((node, cost));
624                        }
625                    }
626                }
627
628                if let Some(entry) = best {
629                    dp.insert(subset.clone(), entry);
630                }
631            }
632        }
633
634        // Retrieve full set.
635        let full: BTreeSet<usize> = (0..n).collect();
636        dp.remove(&full).map(|(node, _)| node).ok_or_else(|| {
637            JoinOrderError::DisconnectedGraph(
638                "could not find a plan covering all relations".to_string(),
639            )
640        })
641    }
642
643    /// Estimate the cost of joining two sub-plans.
644    fn estimate_join_cost(
645        &self,
646        left: &JoinPlanNode,
647        right: &JoinPlanNode,
648        conditions: &[JoinCondition],
649    ) -> (u64, u64) {
650        let left_rows = left.estimated_output_rows().max(1);
651        let right_rows = right.estimated_output_rows().max(1);
652
653        let selectivity = estimate_selectivity(
654            left_rows,
655            right_rows,
656            conditions.len(),
657            self.config.default_selectivity,
658        );
659
660        let output_rows =
661            ((left_rows as f64 * right_rows as f64 * selectivity).ceil() as u64).max(1);
662
663        let use_hash = right_rows > self.config.hash_join_threshold;
664        let join_cost = if use_hash {
665            // hash join: build + probe ≈ left + right + output
666            left_rows + right_rows + output_rows
667        } else {
668            // nested loop: left * right (but at least left + right)
669            (left_rows.saturating_mul(right_rows)).max(left_rows + right_rows)
670        };
671
672        let total_cost = left
673            .cost()
674            .saturating_add(right.cost())
675            .saturating_add(join_cost);
676        (total_cost, output_rows)
677    }
678
679    /// Find applicable join conditions between two sets of relations.
680    fn find_conditions(
681        left_rels: &HashSet<String>,
682        right_rels: &HashSet<String>,
683        all_conditions: &[JoinCondition],
684    ) -> Vec<JoinCondition> {
685        all_conditions
686            .iter()
687            .filter(|c| {
688                (left_rels.contains(&c.left_relation) && right_rels.contains(&c.right_relation))
689                    || (left_rels.contains(&c.right_relation)
690                        && right_rels.contains(&c.left_relation))
691            })
692            .cloned()
693            .collect()
694    }
695
696    /// Construct the appropriate join node based on config.
697    fn make_join_node(
698        &self,
699        left: JoinPlanNode,
700        right: JoinPlanNode,
701        conditions: Vec<JoinCondition>,
702        estimated_cost: u64,
703        estimated_rows: u64,
704    ) -> JoinPlanNode {
705        let right_rows = right.estimated_output_rows();
706        let use_hash = right_rows > self.config.hash_join_threshold;
707
708        let (left, right) = if self.config.prefer_small_left && use_hash {
709            if left.estimated_output_rows() > right.estimated_output_rows() {
710                (right, left)
711            } else {
712                (left, right)
713            }
714        } else {
715            (left, right)
716        };
717
718        if use_hash {
719            JoinPlanNode::HashJoin {
720                left: Box::new(left),
721                right: Box::new(right),
722                conditions,
723                estimated_cost,
724                estimated_rows,
725            }
726        } else {
727            JoinPlanNode::NestedLoopJoin {
728                left: Box::new(left),
729                right: Box::new(right),
730                conditions,
731                estimated_cost,
732                estimated_rows,
733            }
734        }
735    }
736}
737
738impl Default for JoinOrderOptimizer {
739    fn default() -> Self {
740        Self::with_default()
741    }
742}
743
744// ---------------------------------------------------------------------------
745// Tests
746// ---------------------------------------------------------------------------
747
748#[cfg(test)]
749mod tests {
750    use super::*;
751
752    #[test]
753    fn test_relation_new() {
754        let r = Relation::new("users", 3, 1000);
755        assert_eq!(r.name, "users");
756        assert_eq!(r.arity, 3);
757        assert_eq!(r.estimated_rows, 1000);
758        assert!(r.bound_columns.is_empty());
759    }
760
761    #[test]
762    fn test_relation_with_binding() {
763        let r = Relation::new("users", 3, 1000)
764            .with_binding(0)
765            .with_binding(2);
766        assert!(r.bound_columns.contains(&0));
767        assert!(r.bound_columns.contains(&2));
768        assert!(!r.bound_columns.contains(&1));
769        assert_eq!(r.bound_columns.len(), 2);
770    }
771
772    #[test]
773    fn test_relation_selectivity() {
774        let r = Relation::new("users", 4, 1000)
775            .with_binding(0)
776            .with_binding(1);
777        let sel = r.selectivity();
778        assert!((sel - 0.5).abs() < 1e-10);
779
780        let r_zero = Relation::new("empty", 0, 0);
781        assert!((r_zero.selectivity() - 0.0).abs() < 1e-10);
782    }
783
784    #[test]
785    fn test_join_config_default() {
786        let cfg = JoinOptimizerConfig::default();
787        assert_eq!(cfg.max_relations, 10);
788        assert_eq!(cfg.hash_join_threshold, 100);
789        assert!((cfg.default_selectivity - 0.1).abs() < 1e-10);
790        assert!(cfg.prefer_small_left);
791    }
792
793    #[test]
794    fn test_greedy_single_relation() {
795        let opt = JoinOrderOptimizer::with_default();
796        let rels = vec![Relation::new("users", 3, 100)];
797        let plan = opt.optimize(&rels, &[]).expect("should succeed");
798        assert!(matches!(plan.root, JoinPlanNode::Scan { .. }));
799        assert_eq!(plan.stats.relations_scanned, 1);
800        assert_eq!(plan.stats.joins_performed, 0);
801    }
802
803    #[test]
804    fn test_greedy_two_relations() {
805        let opt = JoinOrderOptimizer::with_default();
806        let rels = vec![
807            Relation::new("users", 2, 500),
808            Relation::new("orders", 3, 2000),
809        ];
810        let conds = vec![JoinCondition {
811            left_relation: "users".to_string(),
812            left_column: 0,
813            right_relation: "orders".to_string(),
814            right_column: 1,
815        }];
816        let plan = opt.optimize(&rels, &conds).expect("should succeed");
817        assert_eq!(plan.stats.relations_scanned, 2);
818        assert_eq!(plan.stats.joins_performed, 1);
819        assert!(plan.root.cost() > 0);
820    }
821
822    #[test]
823    fn test_greedy_three_relations() {
824        let opt = JoinOrderOptimizer::with_default();
825        let rels = vec![
826            Relation::new("a", 2, 100),
827            Relation::new("b", 2, 200),
828            Relation::new("c", 2, 300),
829        ];
830        let conds = vec![
831            JoinCondition {
832                left_relation: "a".to_string(),
833                left_column: 0,
834                right_relation: "b".to_string(),
835                right_column: 0,
836            },
837            JoinCondition {
838                left_relation: "b".to_string(),
839                left_column: 1,
840                right_relation: "c".to_string(),
841                right_column: 0,
842            },
843        ];
844        let plan = opt.optimize(&rels, &conds).expect("should succeed");
845        assert_eq!(plan.stats.relations_scanned, 3);
846        assert_eq!(plan.stats.joins_performed, 2);
847        assert!(plan.root.depth() >= 2);
848    }
849
850    #[test]
851    fn test_dp_two_relations() {
852        let opt = JoinOrderOptimizer::with_default();
853        let rels = vec![Relation::new("x", 2, 50), Relation::new("y", 2, 80)];
854        let conds = vec![JoinCondition {
855            left_relation: "x".to_string(),
856            left_column: 0,
857            right_relation: "y".to_string(),
858            right_column: 0,
859        }];
860        let plan = opt.optimize(&rels, &conds).expect("should succeed");
861        assert_eq!(plan.stats.relations_scanned, 2);
862        assert_eq!(plan.stats.joins_performed, 1);
863    }
864
865    #[test]
866    fn test_dp_three_relations() {
867        let opt = JoinOrderOptimizer::with_default();
868        let rels = vec![
869            Relation::new("r1", 2, 10),
870            Relation::new("r2", 2, 20),
871            Relation::new("r3", 2, 30),
872        ];
873        let conds = vec![
874            JoinCondition {
875                left_relation: "r1".to_string(),
876                left_column: 0,
877                right_relation: "r2".to_string(),
878                right_column: 0,
879            },
880            JoinCondition {
881                left_relation: "r2".to_string(),
882                left_column: 1,
883                right_relation: "r3".to_string(),
884                right_column: 0,
885            },
886        ];
887        let plan = opt.optimize(&rels, &conds).expect("should succeed");
888        assert_eq!(plan.stats.relations_scanned, 3);
889        assert_eq!(plan.stats.joins_performed, 2);
890        assert!(plan.root.depth() >= 2);
891    }
892
893    #[test]
894    fn test_optimize_uses_greedy_when_too_many() {
895        let cfg = JoinOptimizerConfig {
896            max_relations: 2,
897            ..Default::default()
898        };
899        let opt = JoinOrderOptimizer::new(cfg);
900        let rels = vec![
901            Relation::new("a", 2, 10),
902            Relation::new("b", 2, 20),
903            Relation::new("c", 2, 30),
904        ];
905        let conds = vec![
906            JoinCondition {
907                left_relation: "a".to_string(),
908                left_column: 0,
909                right_relation: "b".to_string(),
910                right_column: 0,
911            },
912            JoinCondition {
913                left_relation: "b".to_string(),
914                left_column: 1,
915                right_relation: "c".to_string(),
916                right_column: 0,
917            },
918        ];
919        // Should succeed using greedy fallback (3 > max_relations=2)
920        let plan = opt.optimize(&rels, &conds).expect("greedy fallback");
921        assert_eq!(plan.stats.relations_scanned, 3);
922    }
923
924    #[test]
925    fn test_optimize_no_relations_error() {
926        let opt = JoinOrderOptimizer::with_default();
927        let result = opt.optimize(&[], &[]);
928        assert!(result.is_err());
929        assert!(matches!(result, Err(JoinOrderError::NoRelations)));
930    }
931
932    #[test]
933    fn test_join_plan_node_cost() {
934        let node = JoinPlanNode::Scan {
935            relation: "t".to_string(),
936            estimated_cost: 42,
937        };
938        assert_eq!(node.cost(), 42);
939        assert!(node.cost() > 0);
940    }
941
942    #[test]
943    fn test_join_plan_node_depth() {
944        let leaf = JoinPlanNode::Scan {
945            relation: "t".to_string(),
946            estimated_cost: 10,
947        };
948        assert_eq!(leaf.depth(), 1);
949
950        let join = JoinPlanNode::HashJoin {
951            left: Box::new(JoinPlanNode::Scan {
952                relation: "a".to_string(),
953                estimated_cost: 5,
954            }),
955            right: Box::new(JoinPlanNode::Scan {
956                relation: "b".to_string(),
957                estimated_cost: 10,
958            }),
959            conditions: vec![],
960            estimated_cost: 20,
961            estimated_rows: 8,
962        };
963        assert_eq!(join.depth(), 2);
964    }
965
966    #[test]
967    fn test_join_plan_node_relations() {
968        let join = JoinPlanNode::HashJoin {
969            left: Box::new(JoinPlanNode::Scan {
970                relation: "a".to_string(),
971                estimated_cost: 5,
972            }),
973            right: Box::new(JoinPlanNode::Scan {
974                relation: "b".to_string(),
975                estimated_cost: 10,
976            }),
977            conditions: vec![],
978            estimated_cost: 20,
979            estimated_rows: 8,
980        };
981        let rels = join.relations_involved();
982        assert!(rels.contains(&"a".to_string()));
983        assert!(rels.contains(&"b".to_string()));
984        assert_eq!(rels.len(), 2);
985    }
986
987    #[test]
988    fn test_join_plan_format_tree() {
989        let opt = JoinOrderOptimizer::with_default();
990        let rels = vec![Relation::new("a", 2, 100), Relation::new("b", 2, 200)];
991        let conds = vec![JoinCondition {
992            left_relation: "a".to_string(),
993            left_column: 0,
994            right_relation: "b".to_string(),
995            right_column: 0,
996        }];
997        let plan = opt.optimize(&rels, &conds).expect("ok");
998        let tree = plan.format_tree();
999        assert!(!tree.is_empty());
1000    }
1001
1002    #[test]
1003    fn test_join_plan_format_dot() {
1004        let opt = JoinOrderOptimizer::with_default();
1005        let rels = vec![Relation::new("a", 2, 100), Relation::new("b", 2, 200)];
1006        let conds = vec![JoinCondition {
1007            left_relation: "a".to_string(),
1008            left_column: 0,
1009            right_relation: "b".to_string(),
1010            right_column: 0,
1011        }];
1012        let plan = opt.optimize(&rels, &conds).expect("ok");
1013        let dot = plan.format_dot();
1014        assert!(dot.contains("digraph"));
1015    }
1016
1017    #[test]
1018    fn test_estimate_selectivity() {
1019        let sel = estimate_selectivity(1000, 2000, 1, 0.1);
1020        assert!(sel > 0.0);
1021        assert!(sel <= 1.0);
1022
1023        // No conditions → cross product selectivity = 1.0
1024        let sel_cross = estimate_selectivity(100, 100, 0, 0.1);
1025        assert!((sel_cross - 1.0).abs() < 1e-10);
1026
1027        // Multiple conditions → smaller selectivity
1028        let sel_one = estimate_selectivity(100, 200, 1, 0.1);
1029        let sel_two = estimate_selectivity(100, 200, 2, 0.1);
1030        assert!(sel_two < sel_one);
1031    }
1032
1033    #[test]
1034    fn test_find_conditions() {
1035        let conds = vec![
1036            JoinCondition {
1037                left_relation: "a".to_string(),
1038                left_column: 0,
1039                right_relation: "b".to_string(),
1040                right_column: 0,
1041            },
1042            JoinCondition {
1043                left_relation: "b".to_string(),
1044                left_column: 1,
1045                right_relation: "c".to_string(),
1046                right_column: 0,
1047            },
1048        ];
1049
1050        let left: HashSet<String> = ["a".to_string()].into_iter().collect();
1051        let right: HashSet<String> = ["b".to_string()].into_iter().collect();
1052        let found = JoinOrderOptimizer::find_conditions(&left, &right, &conds);
1053        assert_eq!(found.len(), 1);
1054        assert_eq!(found[0].left_relation, "a");
1055
1056        let left2: HashSet<String> = ["a".to_string()].into_iter().collect();
1057        let right2: HashSet<String> = ["c".to_string()].into_iter().collect();
1058        let found2 = JoinOrderOptimizer::find_conditions(&left2, &right2, &conds);
1059        assert_eq!(found2.len(), 0);
1060    }
1061
1062    #[test]
1063    fn test_join_stats() {
1064        let opt = JoinOrderOptimizer::with_default();
1065        let rels = vec![
1066            Relation::new("a", 2, 100),
1067            Relation::new("b", 2, 200),
1068            Relation::new("c", 2, 300),
1069        ];
1070        let conds = vec![
1071            JoinCondition {
1072                left_relation: "a".to_string(),
1073                left_column: 0,
1074                right_relation: "b".to_string(),
1075                right_column: 0,
1076            },
1077            JoinCondition {
1078                left_relation: "b".to_string(),
1079                left_column: 1,
1080                right_relation: "c".to_string(),
1081                right_column: 0,
1082            },
1083        ];
1084        let plan = opt.optimize(&rels, &conds).expect("ok");
1085        assert_eq!(plan.stats.relations_scanned, 3);
1086        assert_eq!(plan.stats.joins_performed, 2);
1087        assert!(plan.stats.total_estimated_cost > 0);
1088        assert!(plan.stats.total_estimated_rows > 0);
1089        assert!(plan.stats.plan_depth >= 2);
1090    }
1091
1092    #[test]
1093    fn test_join_order_error_display() {
1094        let e1 = JoinOrderError::NoRelations;
1095        assert!(!e1.to_string().is_empty());
1096
1097        let e2 = JoinOrderError::DisconnectedGraph("parts missing".to_string());
1098        assert!(e2.to_string().contains("disconnected"));
1099
1100        let e3 = JoinOrderError::TooManyRelations { count: 20, max: 10 };
1101        assert!(e3.to_string().contains("20"));
1102
1103        let e4 = JoinOrderError::InvalidCondition("bad ref".to_string());
1104        assert!(e4.to_string().contains("invalid"));
1105    }
1106}