1use std::{
32 cmp::Ordering,
33 collections::{HashMap, HashSet},
34};
35
36use vibesql_ast::{BinaryOperator, Expression};
37
38#[derive(Debug, Clone)]
40struct TableInfo {
41 name: String,
42 local_predicates: Vec<Expression>, local_selectivity: f64, }
45
46#[derive(Debug, Clone, PartialEq)]
48pub struct JoinEdge {
49 pub left_table: String,
51 pub left_column: String,
53 pub right_table: String,
55 pub right_column: String,
57 pub join_type: vibesql_ast::JoinType,
59}
60
61impl JoinEdge {
62 pub fn involves_table(&self, table: &str) -> bool {
64 self.left_table.eq_ignore_ascii_case(table) || self.right_table.eq_ignore_ascii_case(table)
65 }
66
67 pub fn other_table(&self, table: &str) -> Option<String> {
69 if self.left_table.eq_ignore_ascii_case(table) {
70 Some(self.right_table.clone())
71 } else if self.right_table.eq_ignore_ascii_case(table) {
72 Some(self.left_table.clone())
73 } else {
74 None
75 }
76 }
77}
78
79#[derive(Debug, Clone)]
81pub struct Selectivity {
82 pub factor: f64,
84 pub predicate_type: PredicateType,
86}
87
/// Coarse classification of a predicate.
///
/// NOTE(review): the variant meanings are inferred from their names; no code
/// shown here branches on them — confirm against the callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PredicateType {
    /// Presumably a predicate that involves a single table.
    Local,
    /// Presumably an equality between columns of two different tables.
    Equijoin,
    /// Presumably anything that is neither `Local` nor `Equijoin`.
    Complex,
}
98
99#[derive(Debug, Clone)]
101pub struct JoinOrderAnalyzer {
102 tables: HashMap<String, TableInfo>,
104 edges: Vec<JoinEdge>,
106 #[allow(dead_code)]
108 selectivity: HashMap<String, Selectivity>,
109 column_to_table: HashMap<String, String>,
111}
112
113impl Default for JoinOrderAnalyzer {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl JoinOrderAnalyzer {
120 pub fn new() -> Self {
122 Self {
123 tables: HashMap::new(),
124 edges: Vec::new(),
125 selectivity: HashMap::new(),
126 column_to_table: HashMap::new(),
127 }
128 }
129
130 pub fn with_column_map(column_to_table: HashMap<String, String>) -> Self {
132 Self {
133 tables: HashMap::new(),
134 edges: Vec::new(),
135 selectivity: HashMap::new(),
136 column_to_table,
137 }
138 }
139
140 pub fn set_column_map(&mut self, column_to_table: HashMap<String, String>) {
142 self.column_to_table = column_to_table;
143 }
144
145 pub fn register_tables(&mut self, table_names: Vec<String>) {
147 for name in table_names {
148 self.tables.insert(
149 name.to_lowercase(),
150 TableInfo {
151 name: name.to_lowercase(),
152 local_predicates: Vec::new(),
153 local_selectivity: 1.0,
154 },
155 );
156 }
157 }
158
159 pub fn analyze_predicate(&mut self, expr: &Expression, tables: &HashSet<String>) {
161 self.analyze_predicate_with_type(expr, tables, vibesql_ast::JoinType::Inner);
162 }
163
164 pub fn analyze_predicate_with_type(&mut self, expr: &Expression, tables: &HashSet<String>, join_type: vibesql_ast::JoinType) {
166 match expr {
167 Expression::BinaryOp { op: BinaryOperator::And, left, right } => {
169 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
170 eprintln!("[ANALYZER] Decomposing AND expression");
171 }
172 self.analyze_predicate_with_type(left, tables, join_type.clone());
173 self.analyze_predicate_with_type(right, tables, join_type);
174 }
175 Expression::BinaryOp { op: BinaryOperator::Or, .. } => {
178 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
179 eprintln!("[ANALYZER] Analyzing OR expression for common join conditions");
180 }
181
182 let mut branches = Vec::new();
184 self.collect_or_branches(expr, &mut branches);
185
186 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
187 eprintln!("[ANALYZER] Found {} OR branches", branches.len());
188 }
189
190 let mut branch_edges: Vec<Vec<JoinEdge>> = Vec::new();
192 for branch in &branches {
193 let mut branch_analyzer = JoinOrderAnalyzer::new();
194 let table_vec: Vec<String> = tables.iter().cloned().collect();
195 branch_analyzer.register_tables(table_vec);
196 branch_analyzer.analyze_predicate(branch, tables);
197 branch_edges.push(branch_analyzer.edges().to_vec());
198 }
199
200 if !branch_edges.is_empty() {
202 let common_edges = self.find_common_edges(&branch_edges);
203
204 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
205 eprintln!("[ANALYZER] Found {} common join edges across all OR branches", common_edges.len());
206 }
207
208 for edge in common_edges {
210 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
211 eprintln!("[ANALYZER] Added common edge from OR: {}.{} = {}.{}",
212 edge.left_table, edge.left_column, edge.right_table, edge.right_column);
213 }
214 self.edges.push(edge);
215 }
216 }
217 }
218 Expression::BinaryOp { op: BinaryOperator::Equal, left, right } => {
220 let (left_table, left_col) = self.extract_column_ref(left, tables);
221 let (right_table, right_col) = self.extract_column_ref(right, tables);
222
223 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
224 eprintln!("[ANALYZER] analyze_predicate: left_table={:?}, right_table={:?}, left_col={:?}, right_col={:?}",
225 left_table, right_table, left_col, right_col);
226 }
227
228 match (left_table, right_table, left_col, right_col) {
229 (Some(lt), Some(rt), Some(lc), Some(rc)) if lt != rt => {
231 let edge = JoinEdge {
232 left_table: lt.to_lowercase(),
233 left_column: lc.clone(),
234 right_table: rt.to_lowercase(),
235 right_column: rc.clone(),
236 join_type: join_type.clone(),
237 };
238 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
239 eprintln!("[ANALYZER] Added edge: {}.{} = {}.{} (join_type: {:?})", lt, lc, rt, rc, join_type);
240 }
241 self.edges.push(edge);
242 }
243 (Some(table), None, Some(_col), _) => {
245 if let Some(table_info) = self.tables.get_mut(&table.to_lowercase()) {
246 table_info.local_predicates.push(expr.clone());
247 table_info.local_selectivity *= 0.1;
249 }
250 }
251 _ => {
252 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
253 eprintln!("[ANALYZER] Skipped predicate (no match)");
254 }
255 }
256 }
257 }
258 _ => {
260 }
262 }
263 }
264
265 #[allow(clippy::only_used_in_recursion)]
268 fn collect_or_branches(&self, expr: &Expression, branches: &mut Vec<Expression>) {
269 match expr {
270 Expression::BinaryOp { op: BinaryOperator::Or, left, right } => {
271 self.collect_or_branches(left, branches);
273 self.collect_or_branches(right, branches);
274 }
275 _ => {
276 branches.push(expr.clone());
278 }
279 }
280 }
281
282 fn find_common_edges(&self, branch_edges: &[Vec<JoinEdge>]) -> Vec<JoinEdge> {
285 if branch_edges.is_empty() {
286 return Vec::new();
287 }
288
289 let mut common_edges = Vec::new();
291 let first_branch = &branch_edges[0];
292
293 for edge in first_branch {
294 let appears_in_all = branch_edges[1..].iter().all(|branch| {
296 branch.iter().any(|e| self.edges_match(e, edge))
297 });
298
299 if appears_in_all {
300 common_edges.push(edge.clone());
301 }
302 }
303
304 common_edges
305 }
306
307 fn edges_match(&self, e1: &JoinEdge, e2: &JoinEdge) -> bool {
310 let direct = e1.left_table.eq_ignore_ascii_case(&e2.left_table)
312 && e1.left_column.eq_ignore_ascii_case(&e2.left_column)
313 && e1.right_table.eq_ignore_ascii_case(&e2.right_table)
314 && e1.right_column.eq_ignore_ascii_case(&e2.right_column);
315
316 let reverse = e1.left_table.eq_ignore_ascii_case(&e2.right_table)
318 && e1.left_column.eq_ignore_ascii_case(&e2.right_column)
319 && e1.right_table.eq_ignore_ascii_case(&e2.left_table)
320 && e1.right_column.eq_ignore_ascii_case(&e2.left_column);
321
322 direct || reverse
323 }
324
325 fn extract_column_ref(
329 &self,
330 expr: &Expression,
331 tables: &HashSet<String>,
332 ) -> (Option<String>, Option<String>) {
333 match expr {
334 Expression::ColumnRef { table, column } => {
335 if let Some(t) = table {
337 return (Some(t.clone()), Some(column.clone()));
338 }
339
340 let inferred_table = self.infer_table_from_column(column, tables);
342 (inferred_table, Some(column.clone()))
343 }
344 Expression::Literal(_) => (None, None),
345 _ => (None, None),
346 }
347 }
348
349 fn infer_table_from_column(&self, column: &str, tables: &HashSet<String>) -> Option<String> {
354 if self.column_to_table.is_empty() {
356 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
357 eprintln!("[ANALYZER] Warning: No column-to-table map available for column {}", column);
358 }
359 return None;
360 }
361
362 let col_lower = column.to_lowercase();
363 if let Some(table) = self.column_to_table.get(&col_lower) {
364 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
365 eprintln!("[ANALYZER] Schema lookup: {} -> {}", column, table);
366 }
367 if tables.contains(table) {
369 return Some(table.clone());
370 }
371 for t in tables {
373 if t.eq_ignore_ascii_case(table) {
374 return Some(t.clone());
375 }
376 }
377 if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
378 eprintln!("[ANALYZER] Warning: Table {} not in tables set {:?}", table, tables);
379 }
380 } else if std::env::var("JOIN_REORDER_VERBOSE").is_ok() {
381 eprintln!("[ANALYZER] Warning: Column {} not found in schema map (available: {:?})",
382 col_lower, self.column_to_table.keys().take(10).collect::<Vec<_>>());
383 }
384
385 None
386 }
387
388 pub fn find_most_selective_tables(&self) -> Vec<String> {
390 let mut tables: Vec<_> =
391 self.tables.values().filter(|t| !t.local_predicates.is_empty()).collect();
392
393 tables.sort_by(|a, b| {
395 a.local_selectivity.partial_cmp(&b.local_selectivity).unwrap_or(Ordering::Equal)
396 });
397
398 tables.iter().map(|t| t.name.clone()).collect()
399 }
400
401 pub fn build_join_chain(&self, seed_table: &str) -> Vec<String> {
404 let mut chain = vec![seed_table.to_lowercase()];
405 let mut visited = HashSet::new();
406 visited.insert(seed_table.to_lowercase());
407
408 while chain.len() < self.tables.len() {
410 let current_table = chain[chain.len() - 1].clone();
411
412 let mut next_table: Option<String> = None;
414 for edge in &self.edges {
415 if edge.left_table == current_table && !visited.contains(&edge.right_table) {
416 next_table = Some(edge.right_table.clone());
417 break;
418 } else if edge.right_table == current_table && !visited.contains(&edge.left_table) {
419 next_table = Some(edge.left_table.clone());
420 break;
421 }
422 }
423
424 if next_table.is_none() {
426 for table in self.tables.keys() {
427 if !visited.contains(table) {
428 next_table = Some(table.clone());
429 break;
430 }
431 }
432 }
433
434 if let Some(table) = next_table {
435 chain.push(table.clone());
436 visited.insert(table);
437 } else {
438 break;
439 }
440 }
441
442 chain
443 }
444
445 pub fn find_optimal_order(&self) -> Vec<String> {
450 let selective_tables = self.find_most_selective_tables();
452
453 if let Some(seed) = selective_tables.first() {
455 self.build_join_chain(seed)
456 } else {
457 if let Some(table) = self.tables.keys().next() {
459 self.build_join_chain(table)
460 } else {
461 Vec::new()
462 }
463 }
464 }
465
466 pub fn get_join_condition(
468 &self,
469 left_table: &str,
470 right_table: &str,
471 ) -> Option<(String, String)> {
472 let left_lower = left_table.to_lowercase();
473 let right_lower = right_table.to_lowercase();
474
475 for edge in &self.edges {
476 if (edge.left_table == left_lower && edge.right_table == right_lower)
477 || (edge.left_table == right_lower && edge.right_table == left_lower)
478 {
479 return Some((edge.left_column.clone(), edge.right_column.clone()));
480 }
481 }
482 None
483 }
484
485 pub fn edges(&self) -> &[JoinEdge] {
487 &self.edges
488 }
489
490 pub fn tables(&self) -> std::collections::BTreeSet<String> {
492 self.tables.keys().cloned().collect()
493 }
494
495 #[cfg(test)]
497 pub fn add_edge(&mut self, edge: JoinEdge) {
498 self.edges.push(edge);
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an inner-join edge `left_table.left_column = right_table.right_column`.
    fn inner_edge(
        left_table: &str,
        left_column: &str,
        right_table: &str,
        right_column: &str,
    ) -> JoinEdge {
        JoinEdge {
            left_table: left_table.to_string(),
            left_column: left_column.to_string(),
            right_table: right_table.to_string(),
            right_column: right_column.to_string(),
            join_type: vibesql_ast::JoinType::Inner,
        }
    }

    /// Creates an analyzer with `names` registered as tables.
    fn analyzer_with_tables(names: &[&str]) -> JoinOrderAnalyzer {
        let mut analyzer = JoinOrderAnalyzer::new();
        analyzer.register_tables(names.iter().map(|name| name.to_string()).collect());
        analyzer
    }

    #[test]
    fn test_join_edge_involvement() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert!(edge.involves_table("t1"));
        assert!(edge.involves_table("t2"));
        assert!(!edge.involves_table("t3"));
    }

    #[test]
    fn test_join_edge_other_table() {
        let edge = inner_edge("t1", "a", "t2", "b");

        assert_eq!(edge.other_table("t1"), Some("t2".to_string()));
        assert_eq!(edge.other_table("t2"), Some("t1".to_string()));
        assert_eq!(edge.other_table("t3"), None);
    }

    #[test]
    fn test_basic_chain_detection() {
        let mut analyzer = analyzer_with_tables(&["t1", "t2", "t3"]);

        // t1 - t2 - t3 form a simple linear chain.
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));
        analyzer.edges.push(inner_edge("t2", "id", "t3", "id"));

        let chain = analyzer.build_join_chain("t1");
        assert_eq!(chain.len(), 3);
        assert_eq!(chain[0], "t1");
    }

    #[test]
    fn test_most_selective_tables() {
        let mut analyzer = analyzer_with_tables(&["t1", "t2", "t3"]);

        // The predicate content is irrelevant; only its presence matters.
        let dummy_pred = Expression::Literal(vibesql_types::SqlValue::Integer(5));

        for (table, selectivity) in [("t1", 0.1), ("t2", 0.5)].iter() {
            if let Some(table_info) = analyzer.tables.get_mut(*table) {
                table_info.local_predicates.push(dummy_pred.clone());
                table_info.local_selectivity = *selectivity;
            }
        }

        let selective = analyzer.find_most_selective_tables();
        assert_eq!(selective[0], "t1");
    }

    #[test]
    fn test_join_condition_lookup() {
        let mut analyzer = analyzer_with_tables(&["t1", "t2"]);
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("t1", "t2");
        assert!(condition.is_some());
        assert_eq!(condition.unwrap(), ("id".to_string(), "id".to_string()));
    }

    #[test]
    fn test_case_insensitive_tables() {
        // Tables are registered in upper case while the edge uses lower case.
        let mut analyzer = analyzer_with_tables(&["T1", "T2"]);
        analyzer.edges.push(inner_edge("t1", "id", "t2", "id"));

        let condition = analyzer.get_join_condition("T1", "T2");
        assert!(condition.is_some());
    }
}