// vibesql_executor/select/join/reorder.rs

1//! Join order optimization using selectivity-based heuristics
2//!
3//! This module analyzes predicates to determine optimal join orderings.
4//! The goal is to maximize row reduction early in the join chain to minimize
5//! intermediate result size during cascade joins.
6
7//! ## Example
8//!
9//! ```text
10//! SELECT * FROM t1, t2, t3, t4, t5, t6, t7, t8, t9, t10
11//! WHERE a1 = 5                    -- local predicate for t1
12//!   AND a1 = b2                   -- equijoin t1-t2
13//!   AND a2 = b3                   -- equijoin (t1 ∪ t2)-t3
14//!   AND a3 = b4                   -- equijoin (t1 ∪ t2 ∪ t3)-t4
15//!   ...
16//!
17//! Default order (left-to-right cascade):
//! (...(((t1 JOIN t2) JOIN t3) JOIN t4) ... JOIN t10)
19//!
20//! Result: 90-row intermediates at each step (9 rows × 10 rows before equijoin filter)
21//!
22//! Optimal order (with selectivity awareness):
23//! Start with most selective: t1 (filtered to ~1 row by a1=5)
24//! Then: t1 JOIN t2 (1 × 10 = 10 intermediate, filtered to ~1 by a1=b2)
25//! Then: result JOIN t3 (1 × 10 = 10 intermediate, filtered to ~1 by a2=b3)
26//! ...
27//!
28//! Result: 10-row intermediates maximum (much better memory usage)
29//! ```
30
31use std::{
32    cmp::Ordering,
33    collections::{HashMap, HashSet},
34};
35
36use vibesql_ast::{BinaryOperator, Expression};
37
38/// Information about a table and its predicates
39#[derive(Debug, Clone)]
40struct TableInfo {
41    name: String,
42    local_predicates: Vec<Expression>, // Predicates that only reference this table
43    local_selectivity: f64,            // Estimated selectivity of local predicates (0.0-1.0)
44}
45
46/// Information about an equijoin between two tables
47#[derive(Debug, Clone, PartialEq)]
48pub struct JoinEdge {
49    /// Table name on left side of equijoin
50    pub left_table: String,
51    /// Column from left table
52    pub left_column: String,
53    /// Table name on right side of equijoin
54    pub right_table: String,
55    /// Column from right table
56    pub right_column: String,
57    /// Join type (INNER, SEMI, ANTI, etc.)
58    pub join_type: vibesql_ast::JoinType,
59}
60
61impl JoinEdge {
62    /// Check if this edge involves a specific table
63    pub fn involves_table(&self, table: &str) -> bool {
64        self.left_table.eq_ignore_ascii_case(table) || self.right_table.eq_ignore_ascii_case(table)
65    }
66
67    /// Get the other table in this edge (if input is one side)
68    pub fn other_table(&self, table: &str) -> Option<String> {
69        if self.left_table.eq_ignore_ascii_case(table) {
70            Some(self.right_table.clone())
71        } else if self.right_table.eq_ignore_ascii_case(table) {
72            Some(self.left_table.clone())
73        } else {
74            None
75        }
76    }
77}
78
79/// Selectivity information for a predicate
80#[derive(Debug, Clone)]
81pub struct Selectivity {
82    /// Estimated selectivity (0.0 = filters everything, 1.0 = no filtering)
83    pub factor: f64,
84    /// Type of selectivity (local vs equijoin)
85    pub predicate_type: PredicateType,
86}
87
/// Broad category a predicate falls into for join-ordering purposes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PredicateType {
    /// References columns of a single table only (e.g. `a1 > 5`).
    Local,
    /// Equality between columns of two tables (e.g. `a1 = b2`).
    Equijoin,
    /// Any other predicate that spans several tables.
    Complex,
}
98
99/// Analyzes join chains and determines optimal join ordering
100#[derive(Debug, Clone)]
101pub struct JoinOrderAnalyzer {
102    /// Mapping from table name to table info
103    tables: HashMap<String, TableInfo>,
104    /// List of equijoin edges discovered
105    edges: Vec<JoinEdge>,
106    /// Selectivity information for each predicate
107    #[allow(dead_code)]
108    selectivity: HashMap<String, Selectivity>,
109    /// Schema-based column-to-table mapping for resolving unqualified columns
110    column_to_table: HashMap<String, String>,
111}
112
113impl Default for JoinOrderAnalyzer {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl JoinOrderAnalyzer {
120    /// Create a new join order analyzer
121    pub fn new() -> Self {
122        Self {
123            tables: HashMap::new(),
124            edges: Vec::new(),
125            selectivity: HashMap::new(),
126            column_to_table: HashMap::new(),
127        }
128    }
129
130    /// Create a new join order analyzer with schema-based column resolution
131    pub fn with_column_map(column_to_table: HashMap<String, String>) -> Self {
132        Self {
133            tables: HashMap::new(),
134            edges: Vec::new(),
135            selectivity: HashMap::new(),
136            column_to_table,
137        }
138    }
139
140    /// Set the column-to-table mapping for schema-based resolution
141    pub fn set_column_map(&mut self, column_to_table: HashMap<String, String>) {
142        self.column_to_table = column_to_table;
143    }
144
145    /// Register all tables involved in the query
146    pub fn register_tables(&mut self, table_names: Vec<String>) {
147        for name in table_names {
148            self.tables.insert(
149                name.to_lowercase(),
150                TableInfo {
151                    name: name.to_lowercase(),
152                    local_predicates: Vec::new(),
153                    local_selectivity: 1.0,
154                },
155            );
156        }
157    }
158
159    /// Analyze a predicate and extract join edges or local predicates
160    pub fn analyze_predicate(&mut self, expr: &Expression, tables: &HashSet<String>) {
161        self.analyze_predicate_with_type(expr, tables, vibesql_ast::JoinType::Inner);
162    }
163
164    /// Analyze a predicate with an explicit join type
165    pub fn analyze_predicate_with_type(&mut self, expr: &Expression, tables: &HashSet<String>, join_type: vibesql_ast::JoinType) {
166        match expr {
167            // Recursively handle AND expressions
168            Expression::BinaryOp { op: BinaryOperator::And, left, right } => {
169                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
170                    eprintln!("[ANALYZER] Decomposing AND expression");
171                }
172                self.analyze_predicate_with_type(left, tables, join_type.clone());
173                self.analyze_predicate_with_type(right, tables, join_type);
174            }
175            // Handle OR expressions by extracting common join conditions
176            // that appear in ALL branches
177            Expression::BinaryOp { op: BinaryOperator::Or, .. } => {
178                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
179                    eprintln!("[ANALYZER] Analyzing OR expression for common join conditions");
180                }
181
182                // Collect all OR branches into a list
183                let mut branches = Vec::new();
184                self.collect_or_branches(expr, &mut branches);
185
186                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
187                    eprintln!("[ANALYZER] Found {} OR branches", branches.len());
188                }
189
190                // Extract edges from each branch
191                let mut branch_edges: Vec<Vec<JoinEdge>> = Vec::new();
192                for branch in &branches {
193                    let mut branch_analyzer = JoinOrderAnalyzer::new();
194                    let table_vec: Vec<String> = tables.iter().cloned().collect();
195                    branch_analyzer.register_tables(table_vec);
196                    branch_analyzer.analyze_predicate(branch, tables);
197                    branch_edges.push(branch_analyzer.edges().to_vec());
198                }
199
200                // Find edges common to ALL branches
201                if !branch_edges.is_empty() {
202                    let common_edges = self.find_common_edges(&branch_edges);
203
204                    if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
205                        eprintln!("[ANALYZER] Found {} common join edges across all OR branches", common_edges.len());
206                    }
207
208                    // Add common edges to our join graph
209                    for edge in common_edges {
210                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
211                            eprintln!("[ANALYZER] Added common edge from OR: {}.{} = {}.{}",
212                                edge.left_table, edge.left_column, edge.right_table, edge.right_column);
213                        }
214                        self.edges.push(edge);
215                    }
216                }
217            }
218            // Handle simple binary equality operations
219            Expression::BinaryOp { op: BinaryOperator::Equal, left, right } => {
220                let (left_table, left_col) = self.extract_column_ref(left, tables);
221                let (right_table, right_col) = self.extract_column_ref(right, tables);
222
223                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
224                    eprintln!("[ANALYZER] analyze_predicate: left_table={:?}, right_table={:?}, left_col={:?}, right_col={:?}",
225                        left_table, right_table, left_col, right_col);
226                }
227
228                match (left_table, right_table, left_col, right_col) {
229                    // Equijoin: column from one table = column from another
230                    (Some(lt), Some(rt), Some(lc), Some(rc)) if lt != rt => {
231                        let edge = JoinEdge {
232                            left_table: lt.to_lowercase(),
233                            left_column: lc.clone(),
234                            right_table: rt.to_lowercase(),
235                            right_column: rc.clone(),
236                            join_type: join_type.clone(),
237                        };
238                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
239                            eprintln!("[ANALYZER] Added edge: {}.{} = {}.{} (join_type: {:?})", lt, lc, rt, rc, join_type);
240                        }
241                        self.edges.push(edge);
242                    }
243                    // Local predicate: column = constant
244                    (Some(table), None, Some(_col), _) => {
245                        if let Some(table_info) = self.tables.get_mut(&table.to_lowercase()) {
246                            table_info.local_predicates.push(expr.clone());
247                            // Heuristic: equality predicate has ~10% selectivity
248                            table_info.local_selectivity *= 0.1;
249                        }
250                    }
251                    _ => {
252                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
253                            eprintln!("[ANALYZER] Skipped predicate (no match)");
254                        }
255                    }
256                }
257            }
258            // For other operators, analyze for local vs cross-table
259            _ => {
260                // Conservative: mark as complex, don't try to optimize
261            }
262        }
263    }
264
265    /// Collect all branches of an OR expression into a flat list
266    /// Handles nested ORs by flattening them: (A OR B) OR C => [A, B, C]
267    #[allow(clippy::only_used_in_recursion)]
268    fn collect_or_branches(&self, expr: &Expression, branches: &mut Vec<Expression>) {
269        match expr {
270            Expression::BinaryOp { op: BinaryOperator::Or, left, right } => {
271                // Recursively collect from both sides
272                self.collect_or_branches(left, branches);
273                self.collect_or_branches(right, branches);
274            }
275            _ => {
276                // Leaf node - add to branches
277                branches.push(expr.clone());
278            }
279        }
280    }
281
282    /// Find join edges that appear in ALL branches
283    /// An edge is common if it has the same tables and columns in every branch
284    fn find_common_edges(&self, branch_edges: &[Vec<JoinEdge>]) -> Vec<JoinEdge> {
285        if branch_edges.is_empty() {
286            return Vec::new();
287        }
288
289        // Start with edges from first branch
290        let mut common_edges = Vec::new();
291        let first_branch = &branch_edges[0];
292
293        for edge in first_branch {
294            // Check if this edge appears in all other branches
295            let appears_in_all = branch_edges[1..].iter().all(|branch| {
296                branch.iter().any(|e| self.edges_match(e, edge))
297            });
298
299            if appears_in_all {
300                common_edges.push(edge.clone());
301            }
302        }
303
304        common_edges
305    }
306
307    /// Check if two edges represent the same join condition
308    /// Handles both (A=B) and (B=A) as equivalent
309    fn edges_match(&self, e1: &JoinEdge, e2: &JoinEdge) -> bool {
310        // Direct match: left-to-left, right-to-right
311        let direct = e1.left_table.eq_ignore_ascii_case(&e2.left_table)
312            && e1.left_column.eq_ignore_ascii_case(&e2.left_column)
313            && e1.right_table.eq_ignore_ascii_case(&e2.right_table)
314            && e1.right_column.eq_ignore_ascii_case(&e2.right_column);
315
316        // Reverse match: left-to-right, right-to-left (handles A=B vs B=A)
317        let reverse = e1.left_table.eq_ignore_ascii_case(&e2.right_table)
318            && e1.left_column.eq_ignore_ascii_case(&e2.right_column)
319            && e1.right_table.eq_ignore_ascii_case(&e2.left_table)
320            && e1.right_column.eq_ignore_ascii_case(&e2.left_column);
321
322        direct || reverse
323    }
324
325    /// Extract table and column info from an expression
326    /// Returns (table_name, column_name)
327    /// Uses table inference if explicit table prefix is not present
328    fn extract_column_ref(
329        &self,
330        expr: &Expression,
331        tables: &HashSet<String>,
332    ) -> (Option<String>, Option<String>) {
333        match expr {
334            Expression::ColumnRef { table, column } => {
335                // If explicit table prefix exists, use it
336                if let Some(t) = table {
337                    return (Some(t.clone()), Some(column.clone()));
338                }
339
340                // Otherwise, infer table from column prefix
341                let inferred_table = self.infer_table_from_column(column, tables);
342                (inferred_table, Some(column.clone()))
343            }
344            Expression::Literal(_) => (None, None),
345            _ => (None, None),
346        }
347    }
348
349    /// Infer table name from column name using schema-based lookup
350    ///
351    /// Uses the schema-based column-to-table map to resolve column references.
352    /// All column resolution relies solely on actual database schema metadata.
353    fn infer_table_from_column(&self, column: &str, tables: &HashSet<String>) -> Option<String> {
354        // Schema-based lookup only - no heuristic fallbacks
355        if self.column_to_table.is_empty() {
356            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
357                eprintln!("[ANALYZER] Warning: No column-to-table map available for column {}", column);
358            }
359            return None;
360        }
361
362        let col_lower = column.to_lowercase();
363        if let Some(table) = self.column_to_table.get(&col_lower) {
364            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
365                eprintln!("[ANALYZER] Schema lookup: {} -> {}", column, table);
366            }
367            // Verify the table is in our set (could be aliased)
368            if tables.contains(table) {
369                return Some(table.clone());
370            }
371            // Try case-insensitive match
372            for t in tables {
373                if t.eq_ignore_ascii_case(table) {
374                    return Some(t.clone());
375                }
376            }
377            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
378                eprintln!("[ANALYZER] Warning: Table {} not in tables set {:?}", table, tables);
379            }
380        } else if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
381            eprintln!("[ANALYZER] Warning: Column {} not found in schema map (available: {:?})",
382                col_lower, self.column_to_table.keys().take(10).collect::<Vec<_>>());
383        }
384
385        None
386    }
387
388    /// Find all tables that have local predicates (highest selectivity filters)
389    pub fn find_most_selective_tables(&self) -> Vec<String> {
390        let mut tables: Vec<_> =
391            self.tables.values().filter(|t| !t.local_predicates.is_empty()).collect();
392
393        // Sort by selectivity (most selective first)
394        tables.sort_by(|a, b| {
395            a.local_selectivity.partial_cmp(&b.local_selectivity).unwrap_or(Ordering::Equal)
396        });
397
398        tables.iter().map(|t| t.name.clone()).collect()
399    }
400
401    /// Build a join chain starting from a seed table
402    /// Returns list of tables in optimal join order
403    pub fn build_join_chain(&self, seed_table: &str) -> Vec<String> {
404        let mut chain = vec![seed_table.to_lowercase()];
405        let mut visited = HashSet::new();
406        visited.insert(seed_table.to_lowercase());
407
408        // Greedy: follow edges from current table
409        while chain.len() < self.tables.len() {
410            let current_table = chain[chain.len() - 1].clone();
411
412            // Find an edge from current table
413            let mut next_table: Option<String> = None;
414            for edge in &self.edges {
415                if edge.left_table == current_table && !visited.contains(&edge.right_table) {
416                    next_table = Some(edge.right_table.clone());
417                    break;
418                } else if edge.right_table == current_table && !visited.contains(&edge.left_table) {
419                    next_table = Some(edge.left_table.clone());
420                    break;
421                }
422            }
423
424            // If no edge found, pick any unvisited table
425            if next_table.is_none() {
426                for table in self.tables.keys() {
427                    if !visited.contains(table) {
428                        next_table = Some(table.clone());
429                        break;
430                    }
431                }
432            }
433
434            if let Some(table) = next_table {
435                chain.push(table.clone());
436                visited.insert(table);
437            } else {
438                break;
439            }
440        }
441
442        chain
443    }
444
445    /// Find optimal join order given all constraints
446    ///
447    /// Uses heuristic: start with most selective local filters,
448    /// then follow equijoin chains
449    pub fn find_optimal_order(&self) -> Vec<String> {
450        // Find most selective tables (those with local predicates)
451        let selective_tables = self.find_most_selective_tables();
452
453        // Start with most selective, build chain
454        if let Some(seed) = selective_tables.first() {
455            self.build_join_chain(seed)
456        } else {
457            // Fallback: just use first table
458            if let Some(table) = self.tables.keys().next() {
459                self.build_join_chain(table)
460            } else {
461                Vec::new()
462            }
463        }
464    }
465
466    /// Get the equijoin edges that connect two specific tables
467    pub fn get_join_condition(
468        &self,
469        left_table: &str,
470        right_table: &str,
471    ) -> Option<(String, String)> {
472        let left_lower = left_table.to_lowercase();
473        let right_lower = right_table.to_lowercase();
474
475        for edge in &self.edges {
476            if (edge.left_table == left_lower && edge.right_table == right_lower)
477                || (edge.left_table == right_lower && edge.right_table == left_lower)
478            {
479                return Some((edge.left_column.clone(), edge.right_column.clone()));
480            }
481        }
482        None
483    }
484
485    /// Get all equijoin edges
486    pub fn edges(&self) -> &[JoinEdge] {
487        &self.edges
488    }
489
490    /// Get all tables registered in this analyzer
491    pub fn tables(&self) -> std::collections::BTreeSet<String> {
492        self.tables.keys().cloned().collect()
493    }
494
495    /// Add a join edge (for testing)
496    #[cfg(test)]
497    pub fn add_edge(&mut self, edge: JoinEdge) {
498        self.edges.push(edge);
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an INNER equijoin edge `left_table.left_column = right_table.right_column`.
    fn inner_edge(
        left_table: &str,
        left_column: &str,
        right_table: &str,
        right_column: &str,
    ) -> JoinEdge {
        JoinEdge {
            left_table: left_table.to_string(),
            left_column: left_column.to_string(),
            right_table: right_table.to_string(),
            right_column: right_column.to_string(),
            join_type: vibesql_ast::JoinType::Inner,
        }
    }

    #[test]
    fn test_join_edge_involvement() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert!(edge.involves_table("t1"));
        assert!(edge.involves_table("T1"));
        assert!(edge.involves_table("t2"));
        assert!(!edge.involves_table("t3"));
    }

    #[test]
    fn test_join_edge_other_table() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert_eq!(edge.other_table("t1"), Some("t2".to_string()));
        assert_eq!(edge.other_table("t2"), Some("t1".to_string()));
        assert_eq!(edge.other_table("t3"), None);
    }

    #[test]
    fn test_basic_chain_detection() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // Edges t1-t2 and t2-t3 form a simple path.
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));
        analyzer.edges.push(inner_edge("t2", "id", "t3", "id"));

        // The chain must follow the edges: t1 -> t2 -> t3.
        let chain = analyzer.build_join_chain("t1");
        assert_eq!(chain, vec!["t1", "t2", "t3"]);
    }

    #[test]
    fn test_most_selective_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // Any expression works as a placeholder predicate here.
        let dummy_pred = Expression::Literal(vibesql_types::SqlValue::Integer(5));

        // t1 is the most selective table.
        if let Some(table_info) = analyzer.tables.get_mut("t1") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.1;
        }

        // t2 filters less.
        if let Some(table_info) = analyzer.tables.get_mut("t2") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.5;
        }

        // t3 has no local predicate and must not be listed.
        let selective = analyzer.find_most_selective_tables();
        assert_eq!(selective, vec!["t1", "t2"]);
    }

    #[test]
    fn test_join_condition_lookup() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string()]);

        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("t1", "t2");
        assert_eq!(condition, Some(("id".to_string(), "id".to_string())));
        assert_eq!(analyzer.get_join_condition("t1", "t3"), None);
    }

    #[test]
    fn test_case_insensitive_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["T1".to_string(), "T2".to_string()]);

        // Registration stores lower-cased names.
        let registered: Vec<String> = analyzer.tables().into_iter().collect();
        assert_eq!(registered, vec!["t1", "t2"]);

        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        // The lookup must succeed despite the case difference in the arguments.
        let condition = analyzer.get_join_condition("T1", "T2");
        assert_eq!(condition, Some(("id".to_string(), "id".to_string())));
    }

    #[test]
    fn test_empty_analyzer_has_no_order() {
        let analyzer = JoinOrderAnalyzer::default();
        assert!(analyzer.edges().is_empty());
        assert!(analyzer.find_optimal_order().is_empty());
    }
}