vibesql_executor/select/join/
reorder.rs

1//! Join order optimization using selectivity-based heuristics
2//!
3//! This module analyzes predicates to determine optimal join orderings.
4//! The goal is to maximize row reduction early in the join chain to minimize
5//! intermediate result size during cascade joins.
6
7//! ## Example
8//!
9//! ```text
10//! SELECT * FROM t1, t2, t3, t4, t5, t6, t7, t8, t9, t10
11//! WHERE a1 = 5                    -- local predicate for t1
12//!   AND a1 = b2                   -- equijoin t1-t2
13//!   AND a2 = b3                   -- equijoin (t1 ∪ t2)-t3
14//!   AND a3 = b4                   -- equijoin (t1 ∪ t2 ∪ t3)-t4
15//!   ...
16//!
17//! Default order (left-to-right cascade):
18//! ((((((((((t1 JOIN t2) JOIN t3) JOIN t4) ... JOIN t10)
19//!
20//! Result: 90-row intermediates at each step (9 rows × 10 rows before equijoin filter)
21//!
22//! Optimal order (with selectivity awareness):
23//! Start with most selective: t1 (filtered to ~1 row by a1=5)
24//! Then: t1 JOIN t2 (1 × 10 = 10 intermediate, filtered to ~1 by a1=b2)
25//! Then: result JOIN t3 (1 × 10 = 10 intermediate, filtered to ~1 by a2=b3)
26//! ...
27//!
28//! Result: 10-row intermediates maximum (much better memory usage)
29//! ```
30
31use std::{
32    cmp::Ordering,
33    collections::{HashMap, HashSet},
34};
35
36use vibesql_ast::{BinaryOperator, Expression};
37
38/// Information about a table and its predicates
39#[derive(Debug, Clone)]
40struct TableInfo {
41    name: String,
42    local_predicates: Vec<Expression>, // Predicates that only reference this table
43    local_selectivity: f64,            // Estimated selectivity of local predicates (0.0-1.0)
44}
45
46/// Information about an equijoin between two tables
47#[derive(Debug, Clone, PartialEq)]
48pub struct JoinEdge {
49    /// Table name on left side of equijoin
50    pub left_table: String,
51    /// Column from left table
52    pub left_column: String,
53    /// Table name on right side of equijoin
54    pub right_table: String,
55    /// Column from right table
56    pub right_column: String,
57    /// Join type (INNER, SEMI, ANTI, etc.)
58    pub join_type: vibesql_ast::JoinType,
59}
60
61impl JoinEdge {
62    /// Check if this edge involves a specific table
63    pub fn involves_table(&self, table: &str) -> bool {
64        self.left_table.eq_ignore_ascii_case(table) || self.right_table.eq_ignore_ascii_case(table)
65    }
66
67    /// Get the other table in this edge (if input is one side)
68    pub fn other_table(&self, table: &str) -> Option<String> {
69        if self.left_table.eq_ignore_ascii_case(table) {
70            Some(self.right_table.clone())
71        } else if self.right_table.eq_ignore_ascii_case(table) {
72            Some(self.left_table.clone())
73        } else {
74            None
75        }
76    }
77}
78
79/// Selectivity information for a predicate
80#[derive(Debug, Clone)]
81pub struct Selectivity {
82    /// Estimated selectivity (0.0 = filters everything, 1.0 = no filtering)
83    pub factor: f64,
84    /// Type of selectivity (local vs equijoin)
85    pub predicate_type: PredicateType,
86}
87
/// How a predicate relates to the tables it references.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PredicateType {
    /// References a single table (e.g., `a1 > 5`).
    Local,
    /// Equality between columns of two tables (e.g., `a1 = b2`).
    Equijoin,
    /// Any other predicate spanning multiple tables.
    Complex,
}
98
99/// Analyzes join chains and determines optimal join ordering
100#[derive(Debug, Clone)]
101pub struct JoinOrderAnalyzer {
102    /// Mapping from table name to table info
103    tables: HashMap<String, TableInfo>,
104    /// List of equijoin edges discovered
105    edges: Vec<JoinEdge>,
106    /// Selectivity information for each predicate
107    #[allow(dead_code)]
108    selectivity: HashMap<String, Selectivity>,
109    /// Schema-based column-to-table mapping for resolving unqualified columns
110    column_to_table: HashMap<String, String>,
111}
112
113impl Default for JoinOrderAnalyzer {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl JoinOrderAnalyzer {
120    /// Create a new join order analyzer
121    pub fn new() -> Self {
122        Self {
123            tables: HashMap::new(),
124            edges: Vec::new(),
125            selectivity: HashMap::new(),
126            column_to_table: HashMap::new(),
127        }
128    }
129
130    /// Create a new join order analyzer with schema-based column resolution
131    pub fn with_column_map(column_to_table: HashMap<String, String>) -> Self {
132        Self {
133            tables: HashMap::new(),
134            edges: Vec::new(),
135            selectivity: HashMap::new(),
136            column_to_table,
137        }
138    }
139
140    /// Set the column-to-table mapping for schema-based resolution
141    pub fn set_column_map(&mut self, column_to_table: HashMap<String, String>) {
142        self.column_to_table = column_to_table;
143    }
144
145    /// Register all tables involved in the query
146    pub fn register_tables(&mut self, table_names: Vec<String>) {
147        for name in table_names {
148            self.tables.insert(
149                name.to_lowercase(),
150                TableInfo {
151                    name: name.to_lowercase(),
152                    local_predicates: Vec::new(),
153                    local_selectivity: 1.0,
154                },
155            );
156        }
157    }
158
159    /// Analyze a predicate and extract join edges or local predicates
160    pub fn analyze_predicate(&mut self, expr: &Expression, tables: &HashSet<String>) {
161        self.analyze_predicate_with_type(expr, tables, vibesql_ast::JoinType::Inner);
162    }
163
164    /// Analyze a predicate with an explicit join type
165    pub fn analyze_predicate_with_type(
166        &mut self,
167        expr: &Expression,
168        tables: &HashSet<String>,
169        join_type: vibesql_ast::JoinType,
170    ) {
171        match expr {
172            // Recursively handle AND expressions
173            Expression::BinaryOp { op: BinaryOperator::And, left, right } => {
174                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
175                    eprintln!("[ANALYZER] Decomposing AND expression");
176                }
177                self.analyze_predicate_with_type(left, tables, join_type.clone());
178                self.analyze_predicate_with_type(right, tables, join_type);
179            }
180            // Handle OR expressions by extracting common join conditions
181            // that appear in ALL branches
182            Expression::BinaryOp { op: BinaryOperator::Or, .. } => {
183                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
184                    eprintln!("[ANALYZER] Analyzing OR expression for common join conditions");
185                }
186
187                // Collect all OR branches into a list
188                let mut branches = Vec::new();
189                self.collect_or_branches(expr, &mut branches);
190
191                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
192                    eprintln!("[ANALYZER] Found {} OR branches", branches.len());
193                }
194
195                // Extract edges from each branch
196                let mut branch_edges: Vec<Vec<JoinEdge>> = Vec::new();
197                for branch in &branches {
198                    let mut branch_analyzer = JoinOrderAnalyzer::new();
199                    let table_vec: Vec<String> = tables.iter().cloned().collect();
200                    branch_analyzer.register_tables(table_vec);
201                    branch_analyzer.analyze_predicate(branch, tables);
202                    branch_edges.push(branch_analyzer.edges().to_vec());
203                }
204
205                // Find edges common to ALL branches
206                if !branch_edges.is_empty() {
207                    let common_edges = self.find_common_edges(&branch_edges);
208
209                    if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
210                        eprintln!(
211                            "[ANALYZER] Found {} common join edges across all OR branches",
212                            common_edges.len()
213                        );
214                    }
215
216                    // Add common edges to our join graph
217                    for edge in common_edges {
218                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
219                            eprintln!(
220                                "[ANALYZER] Added common edge from OR: {}.{} = {}.{}",
221                                edge.left_table,
222                                edge.left_column,
223                                edge.right_table,
224                                edge.right_column
225                            );
226                        }
227                        self.edges.push(edge);
228                    }
229                }
230            }
231            // Handle simple binary equality operations
232            Expression::BinaryOp { op: BinaryOperator::Equal, left, right } => {
233                let (left_table, left_col) = self.extract_column_ref(left, tables);
234                let (right_table, right_col) = self.extract_column_ref(right, tables);
235
236                if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
237                    eprintln!("[ANALYZER] analyze_predicate: left_table={:?}, right_table={:?}, left_col={:?}, right_col={:?}",
238                        left_table, right_table, left_col, right_col);
239                }
240
241                match (left_table, right_table, left_col, right_col) {
242                    // Equijoin: column from one table = column from another
243                    (Some(lt), Some(rt), Some(lc), Some(rc)) if lt != rt => {
244                        let edge = JoinEdge {
245                            left_table: lt.to_lowercase(),
246                            left_column: lc.clone(),
247                            right_table: rt.to_lowercase(),
248                            right_column: rc.clone(),
249                            join_type: join_type.clone(),
250                        };
251                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
252                            eprintln!(
253                                "[ANALYZER] Added edge: {}.{} = {}.{} (join_type: {:?})",
254                                lt, lc, rt, rc, join_type
255                            );
256                        }
257                        self.edges.push(edge);
258                    }
259                    // Local predicate: column = constant
260                    (Some(table), None, Some(_col), _) => {
261                        if let Some(table_info) = self.tables.get_mut(&table.to_lowercase()) {
262                            table_info.local_predicates.push(expr.clone());
263                            // Heuristic: equality predicate has ~10% selectivity
264                            table_info.local_selectivity *= 0.1;
265                        }
266                    }
267                    _ => {
268                        if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
269                            eprintln!("[ANALYZER] Skipped predicate (no match)");
270                        }
271                    }
272                }
273            }
274            // For other operators, analyze for local vs cross-table
275            _ => {
276                // Conservative: mark as complex, don't try to optimize
277            }
278        }
279    }
280
281    /// Collect all branches of an OR expression into a flat list
282    /// Handles nested ORs by flattening them: (A OR B) OR C => [A, B, C]
283    #[allow(clippy::only_used_in_recursion)]
284    fn collect_or_branches(&self, expr: &Expression, branches: &mut Vec<Expression>) {
285        match expr {
286            Expression::BinaryOp { op: BinaryOperator::Or, left, right } => {
287                // Recursively collect from both sides
288                self.collect_or_branches(left, branches);
289                self.collect_or_branches(right, branches);
290            }
291            _ => {
292                // Leaf node - add to branches
293                branches.push(expr.clone());
294            }
295        }
296    }
297
298    /// Find join edges that appear in ALL branches
299    /// An edge is common if it has the same tables and columns in every branch
300    fn find_common_edges(&self, branch_edges: &[Vec<JoinEdge>]) -> Vec<JoinEdge> {
301        if branch_edges.is_empty() {
302            return Vec::new();
303        }
304
305        // Start with edges from first branch
306        let mut common_edges = Vec::new();
307        let first_branch = &branch_edges[0];
308
309        for edge in first_branch {
310            // Check if this edge appears in all other branches
311            let appears_in_all = branch_edges[1..]
312                .iter()
313                .all(|branch| branch.iter().any(|e| self.edges_match(e, edge)));
314
315            if appears_in_all {
316                common_edges.push(edge.clone());
317            }
318        }
319
320        common_edges
321    }
322
323    /// Check if two edges represent the same join condition
324    /// Handles both (A=B) and (B=A) as equivalent
325    fn edges_match(&self, e1: &JoinEdge, e2: &JoinEdge) -> bool {
326        // Direct match: left-to-left, right-to-right
327        let direct = e1.left_table.eq_ignore_ascii_case(&e2.left_table)
328            && e1.left_column.eq_ignore_ascii_case(&e2.left_column)
329            && e1.right_table.eq_ignore_ascii_case(&e2.right_table)
330            && e1.right_column.eq_ignore_ascii_case(&e2.right_column);
331
332        // Reverse match: left-to-right, right-to-left (handles A=B vs B=A)
333        let reverse = e1.left_table.eq_ignore_ascii_case(&e2.right_table)
334            && e1.left_column.eq_ignore_ascii_case(&e2.right_column)
335            && e1.right_table.eq_ignore_ascii_case(&e2.left_table)
336            && e1.right_column.eq_ignore_ascii_case(&e2.left_column);
337
338        direct || reverse
339    }
340
341    /// Extract table and column info from an expression
342    /// Returns (table_name, column_name)
343    /// Uses table inference if explicit table prefix is not present
344    fn extract_column_ref(
345        &self,
346        expr: &Expression,
347        tables: &HashSet<String>,
348    ) -> (Option<String>, Option<String>) {
349        match expr {
350            Expression::ColumnRef { table, column } => {
351                // If explicit table prefix exists, use it
352                if let Some(t) = table {
353                    return (Some(t.clone()), Some(column.clone()));
354                }
355
356                // Otherwise, infer table from column prefix
357                let inferred_table = self.infer_table_from_column(column, tables);
358                (inferred_table, Some(column.clone()))
359            }
360            Expression::Literal(_) => (None, None),
361            _ => (None, None),
362        }
363    }
364
365    /// Infer table name from column name using schema-based lookup
366    ///
367    /// Uses the schema-based column-to-table map to resolve column references.
368    /// All column resolution relies solely on actual database schema metadata.
369    fn infer_table_from_column(&self, column: &str, tables: &HashSet<String>) -> Option<String> {
370        // Schema-based lookup only - no heuristic fallbacks
371        if self.column_to_table.is_empty() {
372            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
373                eprintln!(
374                    "[ANALYZER] Warning: No column-to-table map available for column {}",
375                    column
376                );
377            }
378            return None;
379        }
380
381        let col_lower = column.to_lowercase();
382        if let Some(table) = self.column_to_table.get(&col_lower) {
383            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
384                eprintln!("[ANALYZER] Schema lookup: {} -> {}", column, table);
385            }
386            // Verify the table is in our set (could be aliased)
387            if tables.contains(table) {
388                return Some(table.clone());
389            }
390            // Try case-insensitive match
391            for t in tables {
392                if t.eq_ignore_ascii_case(table) {
393                    return Some(t.clone());
394                }
395            }
396            if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
397                eprintln!("[ANALYZER] Warning: Table {} not in tables set {:?}", table, tables);
398            }
399        } else if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
400            eprintln!(
401                "[ANALYZER] Warning: Column {} not found in schema map (available: {:?})",
402                col_lower,
403                self.column_to_table.keys().take(10).collect::<Vec<_>>()
404            );
405        }
406
407        None
408    }
409
410    /// Find all tables that have local predicates (highest selectivity filters)
411    pub fn find_most_selective_tables(&self) -> Vec<String> {
412        let mut tables: Vec<_> =
413            self.tables.values().filter(|t| !t.local_predicates.is_empty()).collect();
414
415        // Sort by selectivity (most selective first)
416        tables.sort_by(|a, b| {
417            a.local_selectivity.partial_cmp(&b.local_selectivity).unwrap_or(Ordering::Equal)
418        });
419
420        tables.iter().map(|t| t.name.clone()).collect()
421    }
422
423    /// Build a join chain starting from a seed table
424    /// Returns list of tables in optimal join order
425    pub fn build_join_chain(&self, seed_table: &str) -> Vec<String> {
426        let mut chain = vec![seed_table.to_lowercase()];
427        let mut visited = HashSet::new();
428        visited.insert(seed_table.to_lowercase());
429
430        // Greedy: follow edges from current table
431        while chain.len() < self.tables.len() {
432            let current_table = chain[chain.len() - 1].clone();
433
434            // Find an edge from current table
435            let mut next_table: Option<String> = None;
436            for edge in &self.edges {
437                if edge.left_table == current_table && !visited.contains(&edge.right_table) {
438                    next_table = Some(edge.right_table.clone());
439                    break;
440                } else if edge.right_table == current_table && !visited.contains(&edge.left_table) {
441                    next_table = Some(edge.left_table.clone());
442                    break;
443                }
444            }
445
446            // If no edge found, pick any unvisited table
447            if next_table.is_none() {
448                for table in self.tables.keys() {
449                    if !visited.contains(table) {
450                        next_table = Some(table.clone());
451                        break;
452                    }
453                }
454            }
455
456            if let Some(table) = next_table {
457                chain.push(table.clone());
458                visited.insert(table);
459            } else {
460                break;
461            }
462        }
463
464        chain
465    }
466
467    /// Find optimal join order given all constraints
468    ///
469    /// Uses heuristic: start with most selective local filters,
470    /// then follow equijoin chains
471    pub fn find_optimal_order(&self) -> Vec<String> {
472        // Find most selective tables (those with local predicates)
473        let selective_tables = self.find_most_selective_tables();
474
475        // Start with most selective, build chain
476        if let Some(seed) = selective_tables.first() {
477            self.build_join_chain(seed)
478        } else {
479            // Fallback: just use first table
480            if let Some(table) = self.tables.keys().next() {
481                self.build_join_chain(table)
482            } else {
483                Vec::new()
484            }
485        }
486    }
487
488    /// Get the equijoin edges that connect two specific tables
489    pub fn get_join_condition(
490        &self,
491        left_table: &str,
492        right_table: &str,
493    ) -> Option<(String, String)> {
494        let left_lower = left_table.to_lowercase();
495        let right_lower = right_table.to_lowercase();
496
497        for edge in &self.edges {
498            if (edge.left_table == left_lower && edge.right_table == right_lower)
499                || (edge.left_table == right_lower && edge.right_table == left_lower)
500            {
501                return Some((edge.left_column.clone(), edge.right_column.clone()));
502            }
503        }
504        None
505    }
506
507    /// Get all equijoin edges
508    pub fn edges(&self) -> &[JoinEdge] {
509        &self.edges
510    }
511
512    /// Get all tables registered in this analyzer
513    pub fn tables(&self) -> std::collections::BTreeSet<String> {
514        self.tables.keys().cloned().collect()
515    }
516
517    /// Add a join edge (for testing)
518    #[cfg(test)]
519    pub fn add_edge(&mut self, edge: JoinEdge) {
520        self.edges.push(edge);
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an INNER equijoin edge `left.left_col = right.right_col`.
    fn inner_edge(left: &str, left_col: &str, right: &str, right_col: &str) -> JoinEdge {
        JoinEdge {
            left_table: left.to_string(),
            left_column: left_col.to_string(),
            right_table: right.to_string(),
            right_column: right_col.to_string(),
            join_type: vibesql_ast::JoinType::Inner,
        }
    }

    #[test]
    fn test_join_edge_involvement() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert!(edge.involves_table("t1"));
        assert!(edge.involves_table("t2"));
        assert!(!edge.involves_table("t3"));
        // Endpoint lookup ignores ASCII case.
        assert!(edge.involves_table("T1"));
    }

    #[test]
    fn test_join_edge_other_table() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert_eq!(edge.other_table("t1"), Some("t2".to_string()));
        assert_eq!(edge.other_table("t2"), Some("t1".to_string()));
        assert_eq!(edge.other_table("t3"), None);
        // Endpoint lookup ignores ASCII case.
        assert_eq!(edge.other_table("T2"), Some("t1".to_string()));
    }

    #[test]
    fn test_basic_chain_detection() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // Add edges: t1-t2, t2-t3
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));
        analyzer.edges.push(inner_edge("t2", "id", "t3", "id"));

        // The chain must follow the edges: t1 -> t2 -> t3
        let chain = analyzer.build_join_chain("t1");
        assert_eq!(chain, vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);
    }

    #[test]
    fn test_most_selective_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // Create dummy predicates
        let dummy_pred = Expression::Literal(vibesql_types::SqlValue::Integer(5));

        // Add local predicates to t1 (most selective)
        if let Some(table_info) = analyzer.tables.get_mut("t1") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.1;
        }

        // Add local predicate to t2 (less selective)
        if let Some(table_info) = analyzer.tables.get_mut("t2") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.5;
        }

        // Most selective first; t3 has no local predicate and is left out.
        let selective = analyzer.find_most_selective_tables();
        assert_eq!(selective, vec!["t1".to_string(), "t2".to_string()]);
    }

    #[test]
    fn test_join_condition_lookup() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string()]);

        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("t1", "t2");
        assert_eq!(condition, Some(("id".to_string(), "id".to_string())));

        // Tables without a connecting edge have no join condition.
        assert_eq!(analyzer.get_join_condition("t1", "t3"), None);
    }

    #[test]
    fn test_case_insensitive_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["T1".to_string(), "T2".to_string()]);

        // Registration lowercases the names.
        let registered: Vec<String> = analyzer.tables().into_iter().collect();
        assert_eq!(registered, vec!["t1".to_string(), "t2".to_string()]);

        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        // Should find condition even with case differences
        let condition = analyzer.get_join_condition("T1", "T2");
        assert!(condition.is_some());
    }
}