vibesql_executor/select/join/
reorder.rs

1//! Join order optimization using selectivity-based heuristics
2//!
3//! This module analyzes predicates to determine optimal join orderings.
4//! The goal is to maximize row reduction early in the join chain to minimize
5//! intermediate result size during cascade joins.
6
7//! ## Example
8//!
9//! ```text
10//! SELECT * FROM t1, t2, t3, t4, t5, t6, t7, t8, t9, t10
11//! WHERE a1 = 5                    -- local predicate for t1
12//!   AND a1 = b2                   -- equijoin t1-t2
13//!   AND a2 = b3                   -- equijoin (t1 ∪ t2)-t3
14//!   AND a3 = b4                   -- equijoin (t1 ∪ t2 ∪ t3)-t4
15//!   ...
16//!
17//! Default order (left-to-right cascade):
18//! ((((((((((t1 JOIN t2) JOIN t3) JOIN t4) ... JOIN t10)
19//!
20//! Result: 90-row intermediates at each step (9 rows × 10 rows before equijoin filter)
21//!
22//! Optimal order (with selectivity awareness):
23//! Start with most selective: t1 (filtered to ~1 row by a1=5)
24//! Then: t1 JOIN t2 (1 × 10 = 10 intermediate, filtered to ~1 by a1=b2)
25//! Then: result JOIN t3 (1 × 10 = 10 intermediate, filtered to ~1 by a2=b3)
26//! ...
27//!
28//! Result: 10-row intermediates maximum (much better memory usage)
29//! ```
30
31use std::{
32    cmp::Ordering,
33    collections::{HashMap, HashSet},
34};
35
36use vibesql_ast::{BinaryOperator, Expression};
37
/// Per-table state accumulated while analyzing predicates.
#[derive(Debug, Clone)]
struct TableInfo {
    /// Table name, lower-cased by `register_tables` (identical to its map key).
    name: String,
    /// Predicates that only reference this table. Currently only equality
    /// predicates with one of this table's columns on one side are recorded.
    local_predicates: Vec<Expression>,
    /// Estimated selectivity of `local_predicates` in 0.0-1.0 (1.0 = no filtering).
    /// Starts at 1.0 and is scaled by 0.1 for each recorded equality predicate.
    local_selectivity: f64,
}
45
/// Information about an equijoin between two tables.
///
/// `JoinOrderAnalyzer` creates these from `column = column` predicates whose
/// columns resolve to two different tables, storing both table names
/// lower-cased. Orientation follows the predicate's operand order, so `a = b`
/// and `b = a` produce mirrored edges.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinEdge {
    /// Table name on left side of equijoin
    pub left_table: String,
    /// Column from left table (the column reference's canonical name)
    pub left_column: String,
    /// Table name on right side of equijoin
    pub right_table: String,
    /// Column from right table (the column reference's canonical name)
    pub right_column: String,
    /// Join type (INNER, SEMI, ANTI, etc.) recorded for this edge
    pub join_type: vibesql_ast::JoinType,
}
60
61impl JoinEdge {
62    /// Check if this edge involves a specific table
63    pub fn involves_table(&self, table: &str) -> bool {
64        self.left_table.eq_ignore_ascii_case(table) || self.right_table.eq_ignore_ascii_case(table)
65    }
66
67    /// Get the other table in this edge (if input is one side)
68    pub fn other_table(&self, table: &str) -> Option<String> {
69        if self.left_table.eq_ignore_ascii_case(table) {
70            Some(self.right_table.clone())
71        } else if self.right_table.eq_ignore_ascii_case(table) {
72            Some(self.left_table.clone())
73        } else {
74            None
75        }
76    }
77}
78
/// Selectivity information for a predicate.
///
/// NOTE(review): nothing in this module constructs or reads this type yet; it is
/// only the value type of the analyzer's (dead-code) `selectivity` map.
#[derive(Debug, Clone)]
pub struct Selectivity {
    /// Estimated selectivity (0.0 = filters everything, 1.0 = no filtering)
    pub factor: f64,
    /// Type of selectivity (local vs equijoin)
    pub predicate_type: PredicateType,
}
87
/// Classification of predicates by type.
///
/// NOTE(review): not constructed anywhere in this module; it is only referenced
/// by [`Selectivity`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PredicateType {
    /// Predicate on single table (e.g., a1 > 5)
    Local,
    /// Equijoin between two tables (e.g., a1 = b2)
    Equijoin,
    /// Complex predicate involving multiple tables
    Complex,
}
98
/// Analyzes join chains and determines a heuristic join ordering.
///
/// Usage: register the participating tables, feed predicates through
/// `analyze_predicate`/`analyze_predicate_with_type`, then ask for an order with
/// `find_optimal_order`.
#[derive(Debug, Clone)]
pub struct JoinOrderAnalyzer {
    /// Mapping from lower-cased table name to table info
    tables: HashMap<String, TableInfo>,
    /// List of equijoin edges discovered, in discovery order
    edges: Vec<JoinEdge>,
    /// Selectivity information for each predicate.
    // Only initialized, never read or written elsewhere in this module, hence the
    // dead-code allowance; presumably reserved for future use.
    #[allow(dead_code)]
    selectivity: HashMap<String, Selectivity>,
    /// Schema-based column-to-table mapping for resolving unqualified columns
    /// (looked up with a lower-cased column name)
    column_to_table: HashMap<String, String>,
}
112
113impl Default for JoinOrderAnalyzer {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl JoinOrderAnalyzer {
120    /// Create a new join order analyzer
121    pub fn new() -> Self {
122        Self {
123            tables: HashMap::new(),
124            edges: Vec::new(),
125            selectivity: HashMap::new(),
126            column_to_table: HashMap::new(),
127        }
128    }
129
130    /// Create a new join order analyzer with schema-based column resolution
131    pub fn with_column_map(column_to_table: HashMap<String, String>) -> Self {
132        Self {
133            tables: HashMap::new(),
134            edges: Vec::new(),
135            selectivity: HashMap::new(),
136            column_to_table,
137        }
138    }
139
140    /// Set the column-to-table mapping for schema-based resolution
141    pub fn set_column_map(&mut self, column_to_table: HashMap<String, String>) {
142        self.column_to_table = column_to_table;
143    }
144
145    /// Register all tables involved in the query
146    pub fn register_tables(&mut self, table_names: Vec<String>) {
147        for name in table_names {
148            self.tables.insert(
149                name.to_lowercase(),
150                TableInfo {
151                    name: name.to_lowercase(),
152                    local_predicates: Vec::new(),
153                    local_selectivity: 1.0,
154                },
155            );
156        }
157    }
158
159    /// Analyze a predicate and extract join edges or local predicates
160    pub fn analyze_predicate(&mut self, expr: &Expression, tables: &HashSet<String>) {
161        self.analyze_predicate_with_type(expr, tables, vibesql_ast::JoinType::Inner);
162    }
163
164    /// Analyze a predicate with an explicit join type
165    pub fn analyze_predicate_with_type(
166        &mut self,
167        expr: &Expression,
168        tables: &HashSet<String>,
169        join_type: vibesql_ast::JoinType,
170    ) {
171        match expr {
172            // Recursively handle AND expressions
173            Expression::BinaryOp { op: BinaryOperator::And, left, right } => {
174                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
175                    eprintln!("[ANALYZER] Decomposing AND expression");
176                }
177                self.analyze_predicate_with_type(left, tables, join_type.clone());
178                self.analyze_predicate_with_type(right, tables, join_type);
179            }
180            // Handle OR expressions by extracting common join conditions
181            // that appear in ALL branches
182            Expression::BinaryOp { op: BinaryOperator::Or, .. } => {
183                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
184                    eprintln!("[ANALYZER] Analyzing OR expression for common join conditions");
185                }
186
187                // Collect all OR branches into a list
188                let mut branches = Vec::new();
189                self.collect_or_branches(expr, &mut branches);
190
191                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
192                    eprintln!("[ANALYZER] Found {} OR branches", branches.len());
193                }
194
195                // Extract edges from each branch
196                let mut branch_edges: Vec<Vec<JoinEdge>> = Vec::new();
197                for branch in &branches {
198                    // IMPORTANT: Pass the column_to_table map to branch analyzers
199                    // so they can resolve column names to tables
200                    let mut branch_analyzer =
201                        JoinOrderAnalyzer::with_column_map(self.column_to_table.clone());
202                    let table_vec: Vec<String> = tables.iter().cloned().collect();
203                    branch_analyzer.register_tables(table_vec);
204                    branch_analyzer.analyze_predicate(branch, tables);
205                    branch_edges.push(branch_analyzer.edges().to_vec());
206                }
207
208                // Find edges common to ALL branches
209                if !branch_edges.is_empty() {
210                    let common_edges = self.find_common_edges(&branch_edges);
211
212                    if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
213                        eprintln!(
214                            "[ANALYZER] Found {} common join edges across all OR branches",
215                            common_edges.len()
216                        );
217                    }
218
219                    // Add common edges to our join graph
220                    for edge in common_edges {
221                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
222                            eprintln!(
223                                "[ANALYZER] Added common edge from OR: {}.{} = {}.{}",
224                                edge.left_table,
225                                edge.left_column,
226                                edge.right_table,
227                                edge.right_column
228                            );
229                        }
230                        self.edges.push(edge);
231                    }
232                }
233            }
234            // Handle simple binary equality operations
235            Expression::BinaryOp { op: BinaryOperator::Equal, left, right } => {
236                let (left_table, left_col) = self.extract_column_ref(left, tables);
237                let (right_table, right_col) = self.extract_column_ref(right, tables);
238
239                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
240                    eprintln!("[ANALYZER] analyze_predicate: left_table={:?}, right_table={:?}, left_col={:?}, right_col={:?}",
241                        left_table, right_table, left_col, right_col);
242                }
243
244                match (left_table, right_table, left_col, right_col) {
245                    // Equijoin: column from one table = column from another
246                    (Some(lt), Some(rt), Some(lc), Some(rc)) if lt != rt => {
247                        let edge = JoinEdge {
248                            left_table: lt.to_lowercase(),
249                            left_column: lc.clone(),
250                            right_table: rt.to_lowercase(),
251                            right_column: rc.clone(),
252                            join_type: join_type.clone(),
253                        };
254                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
255                            eprintln!(
256                                "[ANALYZER] Added edge: {}.{} = {}.{} (join_type: {:?})",
257                                lt, lc, rt, rc, join_type
258                            );
259                        }
260                        self.edges.push(edge);
261                    }
262                    // Local predicate: column = constant
263                    (Some(table), None, Some(_col), _) => {
264                        if let Some(table_info) = self.tables.get_mut(&table.to_lowercase()) {
265                            table_info.local_predicates.push(expr.clone());
266                            // Heuristic: equality predicate has ~10% selectivity
267                            table_info.local_selectivity *= 0.1;
268                        }
269                    }
270                    _ => {
271                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
272                            eprintln!("[ANALYZER] Skipped predicate (no match)");
273                        }
274                    }
275                }
276            }
277            // For other operators, analyze for local vs cross-table
278            _ => {
279                // Conservative: mark as complex, don't try to optimize
280            }
281        }
282    }
283
284    /// Collect all branches of an OR expression into a flat list
285    /// Handles nested ORs by flattening them: (A OR B) OR C => [A, B, C]
286    #[allow(clippy::only_used_in_recursion)]
287    fn collect_or_branches(&self, expr: &Expression, branches: &mut Vec<Expression>) {
288        match expr {
289            Expression::BinaryOp { op: BinaryOperator::Or, left, right } => {
290                // Recursively collect from both sides
291                self.collect_or_branches(left, branches);
292                self.collect_or_branches(right, branches);
293            }
294            _ => {
295                // Leaf node - add to branches
296                branches.push(expr.clone());
297            }
298        }
299    }
300
301    /// Find join edges that appear in ALL branches
302    /// An edge is common if it has the same tables and columns in every branch
303    fn find_common_edges(&self, branch_edges: &[Vec<JoinEdge>]) -> Vec<JoinEdge> {
304        if branch_edges.is_empty() {
305            return Vec::new();
306        }
307
308        // Start with edges from first branch
309        let mut common_edges = Vec::new();
310        let first_branch = &branch_edges[0];
311
312        for edge in first_branch {
313            // Check if this edge appears in all other branches
314            let appears_in_all = branch_edges[1..]
315                .iter()
316                .all(|branch| branch.iter().any(|e| self.edges_match(e, edge)));
317
318            if appears_in_all {
319                common_edges.push(edge.clone());
320            }
321        }
322
323        common_edges
324    }
325
326    /// Check if two edges represent the same join condition
327    /// Handles both (A=B) and (B=A) as equivalent
328    fn edges_match(&self, e1: &JoinEdge, e2: &JoinEdge) -> bool {
329        // Direct match: left-to-left, right-to-right
330        let direct = e1.left_table.eq_ignore_ascii_case(&e2.left_table)
331            && e1.left_column.eq_ignore_ascii_case(&e2.left_column)
332            && e1.right_table.eq_ignore_ascii_case(&e2.right_table)
333            && e1.right_column.eq_ignore_ascii_case(&e2.right_column);
334
335        // Reverse match: left-to-right, right-to-left (handles A=B vs B=A)
336        let reverse = e1.left_table.eq_ignore_ascii_case(&e2.right_table)
337            && e1.left_column.eq_ignore_ascii_case(&e2.right_column)
338            && e1.right_table.eq_ignore_ascii_case(&e2.left_table)
339            && e1.right_column.eq_ignore_ascii_case(&e2.left_column);
340
341        direct || reverse
342    }
343
344    /// Extract table and column info from an expression
345    /// Returns (table_name, column_name)
346    /// Uses table inference if explicit table prefix is not present
347    fn extract_column_ref(
348        &self,
349        expr: &Expression,
350        tables: &HashSet<String>,
351    ) -> (Option<String>, Option<String>) {
352        match expr {
353            Expression::ColumnRef(col_id) => {
354                let column = col_id.column_canonical();
355                // If explicit table prefix exists, use it
356                if let Some(t) = col_id.table_canonical() {
357                    return (Some(t.to_string()), Some(column.to_string()));
358                }
359
360                // Otherwise, infer table from column prefix
361                let inferred_table = self.infer_table_from_column(column, tables);
362                (inferred_table, Some(column.to_string()))
363            }
364            Expression::Literal(_) => (None, None),
365            _ => (None, None),
366        }
367    }
368
369    /// Infer table name from column name using schema-based lookup
370    ///
371    /// Uses the schema-based column-to-table map to resolve column references.
372    /// All column resolution relies solely on actual database schema metadata.
373    fn infer_table_from_column(&self, column: &str, tables: &HashSet<String>) -> Option<String> {
374        // Schema-based lookup only - no heuristic fallbacks
375        if self.column_to_table.is_empty() {
376            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
377                eprintln!(
378                    "[ANALYZER] Warning: No column-to-table map available for column {}",
379                    column
380                );
381            }
382            return None;
383        }
384
385        let col_lower = column.to_lowercase();
386        if let Some(table) = self.column_to_table.get(&col_lower) {
387            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
388                eprintln!("[ANALYZER] Schema lookup: {} -> {}", column, table);
389            }
390            // Verify the table is in our set (could be aliased)
391            if tables.contains(table) {
392                return Some(table.clone());
393            }
394            // Try case-insensitive match
395            for t in tables {
396                if t.eq_ignore_ascii_case(table) {
397                    return Some(t.clone());
398                }
399            }
400            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
401                eprintln!("[ANALYZER] Warning: Table {} not in tables set {:?}", table, tables);
402            }
403        } else if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
404            eprintln!(
405                "[ANALYZER] Warning: Column {} not found in schema map (available: {:?})",
406                col_lower,
407                self.column_to_table.keys().take(10).collect::<Vec<_>>()
408            );
409        }
410
411        None
412    }
413
414    /// Find all tables that have local predicates (highest selectivity filters)
415    pub fn find_most_selective_tables(&self) -> Vec<String> {
416        let mut tables: Vec<_> =
417            self.tables.values().filter(|t| !t.local_predicates.is_empty()).collect();
418
419        // Sort by selectivity (most selective first)
420        tables.sort_by(|a, b| {
421            a.local_selectivity.partial_cmp(&b.local_selectivity).unwrap_or(Ordering::Equal)
422        });
423
424        tables.iter().map(|t| t.name.clone()).collect()
425    }
426
427    /// Build a join chain starting from a seed table
428    /// Returns list of tables in optimal join order
429    pub fn build_join_chain(&self, seed_table: &str) -> Vec<String> {
430        let mut chain = vec![seed_table.to_lowercase()];
431        let mut visited = HashSet::new();
432        visited.insert(seed_table.to_lowercase());
433
434        // Greedy: follow edges from current table
435        while chain.len() < self.tables.len() {
436            let current_table = chain[chain.len() - 1].clone();
437
438            // Find an edge from current table
439            let mut next_table: Option<String> = None;
440            for edge in &self.edges {
441                if edge.left_table == current_table && !visited.contains(&edge.right_table) {
442                    next_table = Some(edge.right_table.clone());
443                    break;
444                } else if edge.right_table == current_table && !visited.contains(&edge.left_table) {
445                    next_table = Some(edge.left_table.clone());
446                    break;
447                }
448            }
449
450            // If no edge found, pick any unvisited table
451            if next_table.is_none() {
452                for table in self.tables.keys() {
453                    if !visited.contains(table) {
454                        next_table = Some(table.clone());
455                        break;
456                    }
457                }
458            }
459
460            if let Some(table) = next_table {
461                chain.push(table.clone());
462                visited.insert(table);
463            } else {
464                break;
465            }
466        }
467
468        chain
469    }
470
471    /// Find optimal join order given all constraints
472    ///
473    /// Uses heuristic: start with most selective local filters,
474    /// then follow equijoin chains
475    pub fn find_optimal_order(&self) -> Vec<String> {
476        // Find most selective tables (those with local predicates)
477        let selective_tables = self.find_most_selective_tables();
478
479        // Start with most selective, build chain
480        if let Some(seed) = selective_tables.first() {
481            self.build_join_chain(seed)
482        } else {
483            // Fallback: just use first table
484            if let Some(table) = self.tables.keys().next() {
485                self.build_join_chain(table)
486            } else {
487                Vec::new()
488            }
489        }
490    }
491
492    /// Get the equijoin edges that connect two specific tables
493    pub fn get_join_condition(
494        &self,
495        left_table: &str,
496        right_table: &str,
497    ) -> Option<(String, String)> {
498        let left_lower = left_table.to_lowercase();
499        let right_lower = right_table.to_lowercase();
500
501        for edge in &self.edges {
502            if (edge.left_table == left_lower && edge.right_table == right_lower)
503                || (edge.left_table == right_lower && edge.right_table == left_lower)
504            {
505                return Some((edge.left_column.clone(), edge.right_column.clone()));
506            }
507        }
508        None
509    }
510
511    /// Get all equijoin edges
512    pub fn edges(&self) -> &[JoinEdge] {
513        &self.edges
514    }
515
516    /// Get all tables registered in this analyzer
517    pub fn tables(&self) -> std::collections::BTreeSet<String> {
518        self.tables.keys().cloned().collect()
519    }
520
521    /// Add a join edge (for testing)
522    #[cfg(test)]
523    pub fn add_edge(&mut self, edge: JoinEdge) {
524        self.edges.push(edge);
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an INNER equijoin edge `left_table.left_column = right_table.right_column`.
    fn edge(left_table: &str, left_column: &str, right_table: &str, right_column: &str) -> JoinEdge {
        JoinEdge {
            left_table: left_table.to_string(),
            left_column: left_column.to_string(),
            right_table: right_table.to_string(),
            right_column: right_column.to_string(),
            join_type: vibesql_ast::JoinType::Inner,
        }
    }

    #[test]
    fn test_join_edge_involvement() {
        let edge = edge("t1", "a", "t2", "b");

        assert!(edge.involves_table("t1"));
        assert!(edge.involves_table("t2"));
        assert!(!edge.involves_table("t3"));
    }

    #[test]
    fn test_join_edge_other_table() {
        let edge = edge("t1", "a", "t2", "b");

        assert_eq!(edge.other_table("t1"), Some("t2".to_string()));
        assert_eq!(edge.other_table("t2"), Some("t1".to_string()));
        assert_eq!(edge.other_table("t3"), None);
    }

    #[test]
    fn test_basic_chain_detection() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // Add edges: t1-t2, t2-t3
        analyzer.edges.push(edge("t1", "id", "t2", "id"));
        analyzer.edges.push(edge("t2", "id", "t3", "id"));

        // Should follow edges: t1 -> t2 -> t3
        let chain = analyzer.build_join_chain("t1");
        assert_eq!(chain, vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);
    }

    #[test]
    fn test_most_selective_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // Create dummy predicates
        let dummy_pred = Expression::Literal(vibesql_types::SqlValue::Integer(5));

        // Add local predicates to t1 (most selective)
        if let Some(table_info) = analyzer.tables.get_mut("t1") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.1;
        }

        // Add local predicate to t2 (less selective)
        if let Some(table_info) = analyzer.tables.get_mut("t2") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.5;
        }

        // Most selective first; t3 has no local predicates and is excluded
        let selective = analyzer.find_most_selective_tables();
        assert_eq!(selective, vec!["t1".to_string(), "t2".to_string()]);
    }

    #[test]
    fn test_join_condition_lookup() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string()]);

        analyzer.edges.push(edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("t1", "t2");
        assert_eq!(condition, Some(("id".to_string(), "id".to_string())));
    }

    #[test]
    fn test_case_insensitive_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["T1".to_string(), "T2".to_string()]);

        analyzer.edges.push(edge("t1", "id", "t2", "id"));

        // Should find condition even with case differences
        let condition = analyzer.get_join_condition("T1", "T2");
        assert_eq!(condition, Some(("id".to_string(), "id".to_string())));
    }

    #[test]
    fn test_edges_match_ignores_orientation_and_case() {
        let analyzer = JoinOrderAnalyzer::new();
        let forward = edge("t1", "a", "t2", "b");
        let reversed = edge("T2", "B", "t1", "A");
        let other = edge("t1", "a", "t3", "b");

        assert!(analyzer.edges_match(&forward, &forward));
        assert!(analyzer.edges_match(&forward, &reversed));
        assert!(!analyzer.edges_match(&forward, &other));
    }

    #[test]
    fn test_find_common_edges() {
        let analyzer = JoinOrderAnalyzer::new();
        let forward = edge("t1", "a", "t2", "b");
        let reversed = edge("t2", "b", "t1", "a");
        let other = edge("t1", "a", "t3", "c");

        let branches = vec![vec![forward.clone(), other], vec![reversed]];
        assert_eq!(analyzer.find_common_edges(&branches), vec![forward]);
        assert!(analyzer.find_common_edges(&[]).is_empty());
    }
}