1use std::{
32 cmp::Ordering,
33 collections::{HashMap, HashSet},
34};
35
36use vibesql_ast::{BinaryOperator, Expression};
37
38#[derive(Debug, Clone)]
40struct TableInfo {
41 name: String,
42 local_predicates: Vec<Expression>, local_selectivity: f64, }
45
46#[derive(Debug, Clone, PartialEq)]
48pub struct JoinEdge {
49 pub left_table: String,
51 pub left_column: String,
53 pub right_table: String,
55 pub right_column: String,
57 pub join_type: vibesql_ast::JoinType,
59}
60
61impl JoinEdge {
62 pub fn involves_table(&self, table: &str) -> bool {
64 self.left_table.eq_ignore_ascii_case(table) || self.right_table.eq_ignore_ascii_case(table)
65 }
66
67 pub fn other_table(&self, table: &str) -> Option<String> {
69 if self.left_table.eq_ignore_ascii_case(table) {
70 Some(self.right_table.clone())
71 } else if self.right_table.eq_ignore_ascii_case(table) {
72 Some(self.left_table.clone())
73 } else {
74 None
75 }
76 }
77}
78
79#[derive(Debug, Clone)]
81pub struct Selectivity {
82 pub factor: f64,
84 pub predicate_type: PredicateType,
86}
87
/// Coarse classification of a predicate, used to tag a [`Selectivity`].
// NOTE(review): the variant meanings below are inferred from their names;
// this file never constructs or matches on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PredicateType {
    /// Presumably a predicate that touches a single table.
    Local,
    /// Presumably an equality predicate between columns of two tables.
    Equijoin,
    /// Presumably any predicate that fits neither category above.
    Complex,
}
98
99#[derive(Debug, Clone)]
101pub struct JoinOrderAnalyzer {
102 tables: HashMap<String, TableInfo>,
104 edges: Vec<JoinEdge>,
106 #[allow(dead_code)]
108 selectivity: HashMap<String, Selectivity>,
109 column_to_table: HashMap<String, String>,
111}
112
113impl Default for JoinOrderAnalyzer {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl JoinOrderAnalyzer {
120 pub fn new() -> Self {
122 Self {
123 tables: HashMap::new(),
124 edges: Vec::new(),
125 selectivity: HashMap::new(),
126 column_to_table: HashMap::new(),
127 }
128 }
129
130 pub fn with_column_map(column_to_table: HashMap<String, String>) -> Self {
132 Self {
133 tables: HashMap::new(),
134 edges: Vec::new(),
135 selectivity: HashMap::new(),
136 column_to_table,
137 }
138 }
139
140 pub fn set_column_map(&mut self, column_to_table: HashMap<String, String>) {
142 self.column_to_table = column_to_table;
143 }
144
145 pub fn register_tables(&mut self, table_names: Vec<String>) {
147 for name in table_names {
148 self.tables.insert(
149 name.to_lowercase(),
150 TableInfo {
151 name: name.to_lowercase(),
152 local_predicates: Vec::new(),
153 local_selectivity: 1.0,
154 },
155 );
156 }
157 }
158
159 pub fn analyze_predicate(&mut self, expr: &Expression, tables: &HashSet<String>) {
161 self.analyze_predicate_with_type(expr, tables, vibesql_ast::JoinType::Inner);
162 }
163
164 pub fn analyze_predicate_with_type(
166 &mut self,
167 expr: &Expression,
168 tables: &HashSet<String>,
169 join_type: vibesql_ast::JoinType,
170 ) {
171 match expr {
172 Expression::BinaryOp { op: BinaryOperator::And, left, right } => {
174 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
175 eprintln!("[ANALYZER] Decomposing AND expression");
176 }
177 self.analyze_predicate_with_type(left, tables, join_type.clone());
178 self.analyze_predicate_with_type(right, tables, join_type);
179 }
180 Expression::BinaryOp { op: BinaryOperator::Or, .. } => {
183 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
184 eprintln!("[ANALYZER] Analyzing OR expression for common join conditions");
185 }
186
187 let mut branches = Vec::new();
189 self.collect_or_branches(expr, &mut branches);
190
191 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
192 eprintln!("[ANALYZER] Found {} OR branches", branches.len());
193 }
194
195 let mut branch_edges: Vec<Vec<JoinEdge>> = Vec::new();
197 for branch in &branches {
198 let mut branch_analyzer = JoinOrderAnalyzer::new();
199 let table_vec: Vec<String> = tables.iter().cloned().collect();
200 branch_analyzer.register_tables(table_vec);
201 branch_analyzer.analyze_predicate(branch, tables);
202 branch_edges.push(branch_analyzer.edges().to_vec());
203 }
204
205 if !branch_edges.is_empty() {
207 let common_edges = self.find_common_edges(&branch_edges);
208
209 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
210 eprintln!(
211 "[ANALYZER] Found {} common join edges across all OR branches",
212 common_edges.len()
213 );
214 }
215
216 for edge in common_edges {
218 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
219 eprintln!(
220 "[ANALYZER] Added common edge from OR: {}.{} = {}.{}",
221 edge.left_table,
222 edge.left_column,
223 edge.right_table,
224 edge.right_column
225 );
226 }
227 self.edges.push(edge);
228 }
229 }
230 }
231 Expression::BinaryOp { op: BinaryOperator::Equal, left, right } => {
233 let (left_table, left_col) = self.extract_column_ref(left, tables);
234 let (right_table, right_col) = self.extract_column_ref(right, tables);
235
236 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
237 eprintln!("[ANALYZER] analyze_predicate: left_table={:?}, right_table={:?}, left_col={:?}, right_col={:?}",
238 left_table, right_table, left_col, right_col);
239 }
240
241 match (left_table, right_table, left_col, right_col) {
242 (Some(lt), Some(rt), Some(lc), Some(rc)) if lt != rt => {
244 let edge = JoinEdge {
245 left_table: lt.to_lowercase(),
246 left_column: lc.clone(),
247 right_table: rt.to_lowercase(),
248 right_column: rc.clone(),
249 join_type: join_type.clone(),
250 };
251 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
252 eprintln!(
253 "[ANALYZER] Added edge: {}.{} = {}.{} (join_type: {:?})",
254 lt, lc, rt, rc, join_type
255 );
256 }
257 self.edges.push(edge);
258 }
259 (Some(table), None, Some(_col), _) => {
261 if let Some(table_info) = self.tables.get_mut(&table.to_lowercase()) {
262 table_info.local_predicates.push(expr.clone());
263 table_info.local_selectivity *= 0.1;
265 }
266 }
267 _ => {
268 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
269 eprintln!("[ANALYZER] Skipped predicate (no match)");
270 }
271 }
272 }
273 }
274 _ => {
276 }
278 }
279 }
280
281 #[allow(clippy::only_used_in_recursion)]
284 fn collect_or_branches(&self, expr: &Expression, branches: &mut Vec<Expression>) {
285 match expr {
286 Expression::BinaryOp { op: BinaryOperator::Or, left, right } => {
287 self.collect_or_branches(left, branches);
289 self.collect_or_branches(right, branches);
290 }
291 _ => {
292 branches.push(expr.clone());
294 }
295 }
296 }
297
298 fn find_common_edges(&self, branch_edges: &[Vec<JoinEdge>]) -> Vec<JoinEdge> {
301 if branch_edges.is_empty() {
302 return Vec::new();
303 }
304
305 let mut common_edges = Vec::new();
307 let first_branch = &branch_edges[0];
308
309 for edge in first_branch {
310 let appears_in_all = branch_edges[1..]
312 .iter()
313 .all(|branch| branch.iter().any(|e| self.edges_match(e, edge)));
314
315 if appears_in_all {
316 common_edges.push(edge.clone());
317 }
318 }
319
320 common_edges
321 }
322
323 fn edges_match(&self, e1: &JoinEdge, e2: &JoinEdge) -> bool {
326 let direct = e1.left_table.eq_ignore_ascii_case(&e2.left_table)
328 && e1.left_column.eq_ignore_ascii_case(&e2.left_column)
329 && e1.right_table.eq_ignore_ascii_case(&e2.right_table)
330 && e1.right_column.eq_ignore_ascii_case(&e2.right_column);
331
332 let reverse = e1.left_table.eq_ignore_ascii_case(&e2.right_table)
334 && e1.left_column.eq_ignore_ascii_case(&e2.right_column)
335 && e1.right_table.eq_ignore_ascii_case(&e2.left_table)
336 && e1.right_column.eq_ignore_ascii_case(&e2.left_column);
337
338 direct || reverse
339 }
340
341 fn extract_column_ref(
345 &self,
346 expr: &Expression,
347 tables: &HashSet<String>,
348 ) -> (Option<String>, Option<String>) {
349 match expr {
350 Expression::ColumnRef { table, column } => {
351 if let Some(t) = table {
353 return (Some(t.clone()), Some(column.clone()));
354 }
355
356 let inferred_table = self.infer_table_from_column(column, tables);
358 (inferred_table, Some(column.clone()))
359 }
360 Expression::Literal(_) => (None, None),
361 _ => (None, None),
362 }
363 }
364
365 fn infer_table_from_column(&self, column: &str, tables: &HashSet<String>) -> Option<String> {
370 if self.column_to_table.is_empty() {
372 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
373 eprintln!(
374 "[ANALYZER] Warning: No column-to-table map available for column {}",
375 column
376 );
377 }
378 return None;
379 }
380
381 let col_lower = column.to_lowercase();
382 if let Some(table) = self.column_to_table.get(&col_lower) {
383 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
384 eprintln!("[ANALYZER] Schema lookup: {} -> {}", column, table);
385 }
386 if tables.contains(table) {
388 return Some(table.clone());
389 }
390 for t in tables {
392 if t.eq_ignore_ascii_case(table) {
393 return Some(t.clone());
394 }
395 }
396 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
397 eprintln!("[ANALYZER] Warning: Table {} not in tables set {:?}", table, tables);
398 }
399 } else if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
400 eprintln!(
401 "[ANALYZER] Warning: Column {} not found in schema map (available: {:?})",
402 col_lower,
403 self.column_to_table.keys().take(10).collect::<Vec<_>>()
404 );
405 }
406
407 None
408 }
409
410 pub fn find_most_selective_tables(&self) -> Vec<String> {
412 let mut tables: Vec<_> =
413 self.tables.values().filter(|t| !t.local_predicates.is_empty()).collect();
414
415 tables.sort_by(|a, b| {
417 a.local_selectivity.partial_cmp(&b.local_selectivity).unwrap_or(Ordering::Equal)
418 });
419
420 tables.iter().map(|t| t.name.clone()).collect()
421 }
422
423 pub fn build_join_chain(&self, seed_table: &str) -> Vec<String> {
426 let mut chain = vec![seed_table.to_lowercase()];
427 let mut visited = HashSet::new();
428 visited.insert(seed_table.to_lowercase());
429
430 while chain.len() < self.tables.len() {
432 let current_table = chain[chain.len() - 1].clone();
433
434 let mut next_table: Option<String> = None;
436 for edge in &self.edges {
437 if edge.left_table == current_table && !visited.contains(&edge.right_table) {
438 next_table = Some(edge.right_table.clone());
439 break;
440 } else if edge.right_table == current_table && !visited.contains(&edge.left_table) {
441 next_table = Some(edge.left_table.clone());
442 break;
443 }
444 }
445
446 if next_table.is_none() {
448 for table in self.tables.keys() {
449 if !visited.contains(table) {
450 next_table = Some(table.clone());
451 break;
452 }
453 }
454 }
455
456 if let Some(table) = next_table {
457 chain.push(table.clone());
458 visited.insert(table);
459 } else {
460 break;
461 }
462 }
463
464 chain
465 }
466
467 pub fn find_optimal_order(&self) -> Vec<String> {
472 let selective_tables = self.find_most_selective_tables();
474
475 if let Some(seed) = selective_tables.first() {
477 self.build_join_chain(seed)
478 } else {
479 if let Some(table) = self.tables.keys().next() {
481 self.build_join_chain(table)
482 } else {
483 Vec::new()
484 }
485 }
486 }
487
488 pub fn get_join_condition(
490 &self,
491 left_table: &str,
492 right_table: &str,
493 ) -> Option<(String, String)> {
494 let left_lower = left_table.to_lowercase();
495 let right_lower = right_table.to_lowercase();
496
497 for edge in &self.edges {
498 if (edge.left_table == left_lower && edge.right_table == right_lower)
499 || (edge.left_table == right_lower && edge.right_table == left_lower)
500 {
501 return Some((edge.left_column.clone(), edge.right_column.clone()));
502 }
503 }
504 None
505 }
506
507 pub fn edges(&self) -> &[JoinEdge] {
509 &self.edges
510 }
511
512 pub fn tables(&self) -> std::collections::BTreeSet<String> {
514 self.tables.keys().cloned().collect()
515 }
516
517 #[cfg(test)]
519 pub fn add_edge(&mut self, edge: JoinEdge) {
520 self.edges.push(edge);
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the inner-join edge `lt.lc = rt.rc`.
    fn inner_edge(lt: &str, lc: &str, rt: &str, rc: &str) -> JoinEdge {
        JoinEdge {
            left_table: lt.to_string(),
            left_column: lc.to_string(),
            right_table: rt.to_string(),
            right_column: rc.to_string(),
            join_type: vibesql_ast::JoinType::Inner,
        }
    }

    #[test]
    fn test_join_edge_involvement() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert!(edge.involves_table("t1"));
        assert!(edge.involves_table("t2"));
        assert!(!edge.involves_table("t3"));
        // Matching ignores ASCII case.
        assert!(edge.involves_table("T1"));
    }

    #[test]
    fn test_join_edge_other_table() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert_eq!(edge.other_table("t1"), Some("t2".to_string()));
        assert_eq!(edge.other_table("t2"), Some("t1".to_string()));
        assert_eq!(edge.other_table("T2"), Some("t1".to_string()));
        assert_eq!(edge.other_table("t3"), None);
    }

    #[test]
    fn test_basic_chain_detection() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // t1 - t2 - t3 form a simple chain.
        analyzer.add_edge(inner_edge("t1", "id", "t2", "id"));
        analyzer.add_edge(inner_edge("t2", "id", "t3", "id"));

        let chain = analyzer.build_join_chain("t1");
        assert_eq!(chain, vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);
    }

    #[test]
    fn test_most_selective_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]);

        // The predicate content is irrelevant here; only its presence and
        // the selectivity factor matter.
        let dummy_pred = Expression::Literal(vibesql_types::SqlValue::Integer(5));

        if let Some(table_info) = analyzer.tables.get_mut("t1") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.1;
        }

        if let Some(table_info) = analyzer.tables.get_mut("t2") {
            table_info.local_predicates.push(dummy_pred.clone());
            table_info.local_selectivity = 0.5;
        }

        // t3 has no local predicate and must not be listed at all.
        let selective = analyzer.find_most_selective_tables();
        assert_eq!(selective, vec!["t1".to_string(), "t2".to_string()]);
    }

    #[test]
    fn test_join_condition_lookup() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["t1".to_string(), "t2".to_string()]);

        analyzer.add_edge(inner_edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("t1", "t2");
        assert_eq!(condition, Some(("id".to_string(), "id".to_string())));
        assert_eq!(analyzer.get_join_condition("t1", "t3"), None);
    }

    #[test]
    fn test_case_insensitive_tables() {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(vec!["T1".to_string(), "T2".to_string()]);

        // Registration lowercases the names.
        let registered: Vec<String> = analyzer.tables().into_iter().collect();
        assert_eq!(registered, vec!["t1".to_string(), "t2".to_string()]);

        analyzer.add_edge(inner_edge("t1", "id", "t2", "id"));

        // Lookups accept any casing of the table names.
        let condition = analyzer.get_join_condition("T1", "T2");
        assert_eq!(condition, Some(("id".to_string(), "id".to_string())));
    }
}