1use std::{
32 cmp::Ordering,
33 collections::{HashMap, HashSet},
34};
35
36use vibesql_ast::{BinaryOperator, Expression};
37
38#[derive(Debug, Clone)]
40struct TableInfo {
41 name: String,
42 local_predicates: Vec<Expression>, local_selectivity: f64, }
45
46#[derive(Debug, Clone, PartialEq)]
48pub struct JoinEdge {
49 pub left_table: String,
51 pub left_column: String,
53 pub right_table: String,
55 pub right_column: String,
57 pub join_type: vibesql_ast::JoinType,
59}
60
61impl JoinEdge {
62 pub fn involves_table(&self, table: &str) -> bool {
64 self.left_table.eq_ignore_ascii_case(table) || self.right_table.eq_ignore_ascii_case(table)
65 }
66
67 pub fn other_table(&self, table: &str) -> Option<String> {
69 if self.left_table.eq_ignore_ascii_case(table) {
70 Some(self.right_table.clone())
71 } else if self.right_table.eq_ignore_ascii_case(table) {
72 Some(self.left_table.clone())
73 } else {
74 None
75 }
76 }
77}
78
79#[derive(Debug, Clone)]
81pub struct Selectivity {
82 pub factor: f64,
84 pub predicate_type: PredicateType,
86}
87
/// Classification of a predicate recorded in a [`Selectivity`].
// NOTE(review): the variant descriptions below are inferred from the variant
// names; no code in this file constructs or matches on them — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PredicateType {
    /// Presumably a predicate that touches a single table.
    Local,
    /// Presumably an equality between columns of two tables.
    Equijoin,
    /// Presumably any predicate that fits neither category above.
    Complex,
}
98
99#[derive(Debug, Clone)]
101pub struct JoinOrderAnalyzer {
102 tables: HashMap<String, TableInfo>,
104 edges: Vec<JoinEdge>,
106 #[allow(dead_code)]
108 selectivity: HashMap<String, Selectivity>,
109 column_to_table: HashMap<String, String>,
111}
112
113impl Default for JoinOrderAnalyzer {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl JoinOrderAnalyzer {
120 pub fn new() -> Self {
122 Self {
123 tables: HashMap::new(),
124 edges: Vec::new(),
125 selectivity: HashMap::new(),
126 column_to_table: HashMap::new(),
127 }
128 }
129
130 pub fn with_column_map(column_to_table: HashMap<String, String>) -> Self {
132 Self {
133 tables: HashMap::new(),
134 edges: Vec::new(),
135 selectivity: HashMap::new(),
136 column_to_table,
137 }
138 }
139
140 pub fn set_column_map(&mut self, column_to_table: HashMap<String, String>) {
142 self.column_to_table = column_to_table;
143 }
144
145 pub fn register_tables(&mut self, table_names: Vec<String>) {
147 for name in table_names {
148 self.tables.insert(
149 name.to_lowercase(),
150 TableInfo {
151 name: name.to_lowercase(),
152 local_predicates: Vec::new(),
153 local_selectivity: 1.0,
154 },
155 );
156 }
157 }
158
159 pub fn analyze_predicate(&mut self, expr: &Expression, tables: &HashSet<String>) {
161 self.analyze_predicate_with_type(expr, tables, vibesql_ast::JoinType::Inner);
162 }
163
164 pub fn analyze_predicate_with_type(
166 &mut self,
167 expr: &Expression,
168 tables: &HashSet<String>,
169 join_type: vibesql_ast::JoinType,
170 ) {
171 match expr {
172 Expression::BinaryOp { op: BinaryOperator::And, left, right } => {
174 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
175 eprintln!("[ANALYZER] Decomposing AND expression");
176 }
177 self.analyze_predicate_with_type(left, tables, join_type.clone());
178 self.analyze_predicate_with_type(right, tables, join_type);
179 }
180 Expression::BinaryOp { op: BinaryOperator::Or, .. } => {
183 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
184 eprintln!("[ANALYZER] Analyzing OR expression for common join conditions");
185 }
186
187 let mut branches = Vec::new();
189 self.collect_or_branches(expr, &mut branches);
190
191 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
192 eprintln!("[ANALYZER] Found {} OR branches", branches.len());
193 }
194
195 let mut branch_edges: Vec<Vec<JoinEdge>> = Vec::new();
197 for branch in &branches {
198 let mut branch_analyzer =
201 JoinOrderAnalyzer::with_column_map(self.column_to_table.clone());
202 let table_vec: Vec<String> = tables.iter().cloned().collect();
203 branch_analyzer.register_tables(table_vec);
204 branch_analyzer.analyze_predicate(branch, tables);
205 branch_edges.push(branch_analyzer.edges().to_vec());
206 }
207
208 if !branch_edges.is_empty() {
210 let common_edges = self.find_common_edges(&branch_edges);
211
212 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
213 eprintln!(
214 "[ANALYZER] Found {} common join edges across all OR branches",
215 common_edges.len()
216 );
217 }
218
219 for edge in common_edges {
221 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
222 eprintln!(
223 "[ANALYZER] Added common edge from OR: {}.{} = {}.{}",
224 edge.left_table,
225 edge.left_column,
226 edge.right_table,
227 edge.right_column
228 );
229 }
230 self.edges.push(edge);
231 }
232 }
233 }
234 Expression::BinaryOp { op: BinaryOperator::Equal, left, right } => {
236 let (left_table, left_col) = self.extract_column_ref(left, tables);
237 let (right_table, right_col) = self.extract_column_ref(right, tables);
238
239 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
240 eprintln!("[ANALYZER] analyze_predicate: left_table={:?}, right_table={:?}, left_col={:?}, right_col={:?}",
241 left_table, right_table, left_col, right_col);
242 }
243
244 match (left_table, right_table, left_col, right_col) {
245 (Some(lt), Some(rt), Some(lc), Some(rc)) if lt != rt => {
247 let edge = JoinEdge {
248 left_table: lt.to_lowercase(),
249 left_column: lc.clone(),
250 right_table: rt.to_lowercase(),
251 right_column: rc.clone(),
252 join_type: join_type.clone(),
253 };
254 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
255 eprintln!(
256 "[ANALYZER] Added edge: {}.{} = {}.{} (join_type: {:?})",
257 lt, lc, rt, rc, join_type
258 );
259 }
260 self.edges.push(edge);
261 }
262 (Some(table), None, Some(_col), _) => {
264 if let Some(table_info) = self.tables.get_mut(&table.to_lowercase()) {
265 table_info.local_predicates.push(expr.clone());
266 table_info.local_selectivity *= 0.1;
268 }
269 }
270 _ => {
271 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
272 eprintln!("[ANALYZER] Skipped predicate (no match)");
273 }
274 }
275 }
276 }
277 _ => {
279 }
281 }
282 }
283
284 #[allow(clippy::only_used_in_recursion)]
287 fn collect_or_branches(&self, expr: &Expression, branches: &mut Vec<Expression>) {
288 match expr {
289 Expression::BinaryOp { op: BinaryOperator::Or, left, right } => {
290 self.collect_or_branches(left, branches);
292 self.collect_or_branches(right, branches);
293 }
294 _ => {
295 branches.push(expr.clone());
297 }
298 }
299 }
300
301 fn find_common_edges(&self, branch_edges: &[Vec<JoinEdge>]) -> Vec<JoinEdge> {
304 if branch_edges.is_empty() {
305 return Vec::new();
306 }
307
308 let mut common_edges = Vec::new();
310 let first_branch = &branch_edges[0];
311
312 for edge in first_branch {
313 let appears_in_all = branch_edges[1..]
315 .iter()
316 .all(|branch| branch.iter().any(|e| self.edges_match(e, edge)));
317
318 if appears_in_all {
319 common_edges.push(edge.clone());
320 }
321 }
322
323 common_edges
324 }
325
326 fn edges_match(&self, e1: &JoinEdge, e2: &JoinEdge) -> bool {
329 let direct = e1.left_table.eq_ignore_ascii_case(&e2.left_table)
331 && e1.left_column.eq_ignore_ascii_case(&e2.left_column)
332 && e1.right_table.eq_ignore_ascii_case(&e2.right_table)
333 && e1.right_column.eq_ignore_ascii_case(&e2.right_column);
334
335 let reverse = e1.left_table.eq_ignore_ascii_case(&e2.right_table)
337 && e1.left_column.eq_ignore_ascii_case(&e2.right_column)
338 && e1.right_table.eq_ignore_ascii_case(&e2.left_table)
339 && e1.right_column.eq_ignore_ascii_case(&e2.left_column);
340
341 direct || reverse
342 }
343
344 fn extract_column_ref(
348 &self,
349 expr: &Expression,
350 tables: &HashSet<String>,
351 ) -> (Option<String>, Option<String>) {
352 match expr {
353 Expression::ColumnRef(col_id) => {
354 let column = col_id.column_canonical();
355 if let Some(t) = col_id.table_canonical() {
357 return (Some(t.to_string()), Some(column.to_string()));
358 }
359
360 let inferred_table = self.infer_table_from_column(column, tables);
362 (inferred_table, Some(column.to_string()))
363 }
364 Expression::Literal(_) => (None, None),
365 _ => (None, None),
366 }
367 }
368
369 fn infer_table_from_column(&self, column: &str, tables: &HashSet<String>) -> Option<String> {
374 if self.column_to_table.is_empty() {
376 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
377 eprintln!(
378 "[ANALYZER] Warning: No column-to-table map available for column {}",
379 column
380 );
381 }
382 return None;
383 }
384
385 let col_lower = column.to_lowercase();
386 if let Some(table) = self.column_to_table.get(&col_lower) {
387 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
388 eprintln!("[ANALYZER] Schema lookup: {} -> {}", column, table);
389 }
390 if tables.contains(table) {
392 return Some(table.clone());
393 }
394 for t in tables {
396 if t.eq_ignore_ascii_case(table) {
397 return Some(t.clone());
398 }
399 }
400 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
401 eprintln!("[ANALYZER] Warning: Table {} not in tables set {:?}", table, tables);
402 }
403 } else if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
404 eprintln!(
405 "[ANALYZER] Warning: Column {} not found in schema map (available: {:?})",
406 col_lower,
407 self.column_to_table.keys().take(10).collect::<Vec<_>>()
408 );
409 }
410
411 None
412 }
413
414 pub fn find_most_selective_tables(&self) -> Vec<String> {
416 let mut tables: Vec<_> =
417 self.tables.values().filter(|t| !t.local_predicates.is_empty()).collect();
418
419 tables.sort_by(|a, b| {
421 a.local_selectivity.partial_cmp(&b.local_selectivity).unwrap_or(Ordering::Equal)
422 });
423
424 tables.iter().map(|t| t.name.clone()).collect()
425 }
426
427 pub fn build_join_chain(&self, seed_table: &str) -> Vec<String> {
430 let mut chain = vec![seed_table.to_lowercase()];
431 let mut visited = HashSet::new();
432 visited.insert(seed_table.to_lowercase());
433
434 while chain.len() < self.tables.len() {
436 let current_table = chain[chain.len() - 1].clone();
437
438 let mut next_table: Option<String> = None;
440 for edge in &self.edges {
441 if edge.left_table == current_table && !visited.contains(&edge.right_table) {
442 next_table = Some(edge.right_table.clone());
443 break;
444 } else if edge.right_table == current_table && !visited.contains(&edge.left_table) {
445 next_table = Some(edge.left_table.clone());
446 break;
447 }
448 }
449
450 if next_table.is_none() {
452 for table in self.tables.keys() {
453 if !visited.contains(table) {
454 next_table = Some(table.clone());
455 break;
456 }
457 }
458 }
459
460 if let Some(table) = next_table {
461 chain.push(table.clone());
462 visited.insert(table);
463 } else {
464 break;
465 }
466 }
467
468 chain
469 }
470
471 pub fn find_optimal_order(&self) -> Vec<String> {
476 let selective_tables = self.find_most_selective_tables();
478
479 if let Some(seed) = selective_tables.first() {
481 self.build_join_chain(seed)
482 } else {
483 if let Some(table) = self.tables.keys().next() {
485 self.build_join_chain(table)
486 } else {
487 Vec::new()
488 }
489 }
490 }
491
492 pub fn get_join_condition(
494 &self,
495 left_table: &str,
496 right_table: &str,
497 ) -> Option<(String, String)> {
498 let left_lower = left_table.to_lowercase();
499 let right_lower = right_table.to_lowercase();
500
501 for edge in &self.edges {
502 if (edge.left_table == left_lower && edge.right_table == right_lower)
503 || (edge.left_table == right_lower && edge.right_table == left_lower)
504 {
505 return Some((edge.left_column.clone(), edge.right_column.clone()));
506 }
507 }
508 None
509 }
510
511 pub fn edges(&self) -> &[JoinEdge] {
513 &self.edges
514 }
515
516 pub fn tables(&self) -> std::collections::BTreeSet<String> {
518 self.tables.keys().cloned().collect()
519 }
520
521 #[cfg(test)]
523 pub fn add_edge(&mut self, edge: JoinEdge) {
524 self.edges.push(edge);
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an inner-join edge `left.left_col = right.right_col`.
    fn inner_edge(left: &str, left_col: &str, right: &str, right_col: &str) -> JoinEdge {
        JoinEdge {
            left_table: left.to_string(),
            left_column: left_col.to_string(),
            right_table: right.to_string(),
            right_column: right_col.to_string(),
            join_type: vibesql_ast::JoinType::Inner,
        }
    }

    /// Creates an analyzer with the given tables registered.
    fn analyzer_with(tables: &[&str]) -> JoinOrderAnalyzer {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(tables.iter().map(|name| name.to_string()).collect());
        analyzer
    }

    #[test]
    fn test_join_edge_involvement() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert!(edge.involves_table("t1"));
        assert!(edge.involves_table("t2"));
        assert!(!edge.involves_table("t3"));
    }

    #[test]
    fn test_join_edge_other_table() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert_eq!(edge.other_table("t1"), Some("t2".to_string()));
        assert_eq!(edge.other_table("t2"), Some("t1".to_string()));
        assert_eq!(edge.other_table("t3"), None);
    }

    #[test]
    fn test_basic_chain_detection() {
        let mut analyzer = analyzer_with(&["t1", "t2", "t3"]);

        // t1 - t2 - t3 forms a simple chain.
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));
        analyzer.edges.push(inner_edge("t2", "id", "t3", "id"));

        let chain = analyzer.build_join_chain("t1");
        assert_eq!(chain.len(), 3);
        assert_eq!(chain[0], "t1");
    }

    #[test]
    fn test_most_selective_tables() {
        let mut analyzer = analyzer_with(&["t1", "t2", "t3"]);

        let dummy_pred = Expression::Literal(vibesql_types::SqlValue::Integer(5));

        // t1 is the more selective of the two tables that have a predicate.
        for (table, selectivity) in [("t1", 0.1), ("t2", 0.5)] {
            if let Some(table_info) = analyzer.tables.get_mut(table) {
                table_info.local_predicates.push(dummy_pred.clone());
                table_info.local_selectivity = selectivity;
            }
        }

        let selective = analyzer.find_most_selective_tables();
        assert_eq!(selective[0], "t1");
    }

    #[test]
    fn test_join_condition_lookup() {
        let mut analyzer = analyzer_with(&["t1", "t2"]);
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("t1", "t2");
        assert!(condition.is_some());
        assert_eq!(condition.unwrap(), ("id".to_string(), "id".to_string()));
    }

    #[test]
    fn test_case_insensitive_tables() {
        // Tables are registered in upper case, the edge uses lower case.
        let mut analyzer = analyzer_with(&["T1", "T2"]);
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("T1", "T2");
        assert!(condition.is_some());
    }
}