1use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tracing::{debug, info};
7
8use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
9use crate::sql::parser::ast::{JoinClause, JoinOperator, JoinType};
10
/// Executes joins between two `DataTable`s.
///
/// Equality conditions are served by an in-memory hash index; every other
/// comparison operator falls back to a nested-loop scan.
#[derive(Debug, Clone)]
pub struct HashJoinExecutor {
    /// When `true`, join column names are matched against table column names
    /// without regard to letter case.
    case_insensitive: bool,
}
15
16impl HashJoinExecutor {
17 pub fn new(case_insensitive: bool) -> Self {
18 Self { case_insensitive }
19 }
20
21 pub fn execute_join(
23 &self,
24 left_table: Arc<DataTable>,
25 join_clause: &JoinClause,
26 right_table: Arc<DataTable>,
27 ) -> Result<DataTable> {
28 info!(
29 "Executing {:?} JOIN: {} rows x {} rows",
30 join_clause.join_type,
31 left_table.row_count(),
32 right_table.row_count()
33 );
34
35 let (left_col_name, right_col_name) = self.parse_join_columns(join_clause)?;
37
38 let (left_col_idx, right_col_idx) =
42 self.resolve_join_columns(&left_table, &right_table, &left_col_name, &right_col_name)?;
43
44 let use_hash_join = join_clause.condition.operator == JoinOperator::Equal;
46
47 match join_clause.join_type {
49 JoinType::Inner => {
50 if use_hash_join {
51 self.hash_join_inner(
52 left_table,
53 right_table,
54 left_col_idx,
55 right_col_idx,
56 &left_col_name,
57 &right_col_name,
58 )
59 } else {
60 self.nested_loop_join_inner(
61 left_table,
62 right_table,
63 left_col_idx,
64 right_col_idx,
65 &join_clause.condition.operator,
66 )
67 }
68 }
69 JoinType::Left => {
70 if use_hash_join {
71 self.hash_join_left(
72 left_table,
73 right_table,
74 left_col_idx,
75 right_col_idx,
76 &left_col_name,
77 &right_col_name,
78 )
79 } else {
80 self.nested_loop_join_left(
81 left_table,
82 right_table,
83 left_col_idx,
84 right_col_idx,
85 &join_clause.condition.operator,
86 )
87 }
88 }
89 JoinType::Right => {
90 if use_hash_join {
91 self.hash_join_left(
93 right_table,
94 left_table,
95 right_col_idx,
96 left_col_idx,
97 &right_col_name,
98 &left_col_name,
99 )
100 } else {
101 self.nested_loop_join_left(
103 right_table,
104 left_table,
105 right_col_idx,
106 left_col_idx,
107 &self.reverse_operator(&join_clause.condition.operator),
108 )
109 }
110 }
111 JoinType::Cross => self.cross_join(left_table, right_table),
112 JoinType::Full => {
113 return Err(anyhow!("FULL OUTER JOIN not yet implemented"));
114 }
115 }
116 }
117
118 fn parse_join_columns(&self, join_clause: &JoinClause) -> Result<(String, String)> {
120 Ok((
121 join_clause.condition.left_column.clone(),
122 join_clause.condition.right_column.clone(),
123 ))
124 }
125
126 fn resolve_join_columns(
128 &self,
129 left_table: &DataTable,
130 right_table: &DataTable,
131 left_col_name: &str,
132 right_col_name: &str,
133 ) -> Result<(usize, usize)> {
134 let left_col_idx = if let Ok(idx) = self.find_column_index(left_table, left_col_name) {
136 idx
137 } else if let Ok(idx) = self.find_column_index(right_table, left_col_name) {
138 return Err(anyhow!(
141 "Column '{}' found in right table but specified as left operand. \
142 Please rewrite the condition with columns in correct positions.",
143 left_col_name
144 ));
145 } else {
146 return Err(anyhow!(
147 "Column '{}' not found in either table",
148 left_col_name
149 ));
150 };
151
152 let right_col_idx = if let Ok(idx) = self.find_column_index(right_table, right_col_name) {
154 idx
155 } else if let Ok(idx) = self.find_column_index(left_table, right_col_name) {
156 return Err(anyhow!(
159 "Column '{}' found in left table but specified as right operand. \
160 Please rewrite the condition with columns in correct positions.",
161 right_col_name
162 ));
163 } else {
164 return Err(anyhow!(
165 "Column '{}' not found in either table",
166 right_col_name
167 ));
168 };
169
170 Ok((left_col_idx, right_col_idx))
171 }
172
173 fn find_column_index(&self, table: &DataTable, col_name: &str) -> Result<usize> {
175 let col_name = if let Some(dot_pos) = col_name.rfind('.') {
177 &col_name[dot_pos + 1..]
178 } else {
179 col_name
180 };
181
182 debug!(
183 "Looking for column '{}' in table with columns: {:?}",
184 col_name,
185 table.column_names()
186 );
187
188 table
189 .columns
190 .iter()
191 .position(|col| {
192 if self.case_insensitive {
193 col.name.to_lowercase() == col_name.to_lowercase()
194 } else {
195 col.name == col_name
196 }
197 })
198 .ok_or_else(|| anyhow!("Column '{}' not found in table", col_name))
199 }
200
201 fn hash_join_inner(
203 &self,
204 left_table: Arc<DataTable>,
205 right_table: Arc<DataTable>,
206 left_col_idx: usize,
207 right_col_idx: usize,
208 _left_col_name: &str,
209 _right_col_name: &str,
210 ) -> Result<DataTable> {
211 let start = std::time::Instant::now();
212
213 let (build_table, probe_table, build_col_idx, probe_col_idx, build_is_left) =
215 if left_table.row_count() <= right_table.row_count() {
216 (
217 left_table.clone(),
218 right_table.clone(),
219 left_col_idx,
220 right_col_idx,
221 true,
222 )
223 } else {
224 (
225 right_table.clone(),
226 left_table.clone(),
227 right_col_idx,
228 left_col_idx,
229 false,
230 )
231 };
232
233 debug!(
234 "Building hash index on {} table ({} rows)",
235 if build_is_left { "left" } else { "right" },
236 build_table.row_count()
237 );
238
239 let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
241 for (row_idx, row) in build_table.rows.iter().enumerate() {
242 let key = row.values[build_col_idx].clone();
243 hash_index.entry(key).or_default().push(row_idx);
244 }
245
246 debug!(
247 "Hash index built with {} unique keys in {:?}",
248 hash_index.len(),
249 start.elapsed()
250 );
251
252 let mut result = DataTable::new("joined");
254
255 for col in &left_table.columns {
257 result.add_column(DataColumn {
258 name: col.name.clone(),
259 data_type: col.data_type.clone(),
260 nullable: col.nullable,
261 unique_values: col.unique_values,
262 null_count: col.null_count,
263 metadata: col.metadata.clone(),
264 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
267 }
268
269 for col in &right_table.columns {
271 if !left_table
273 .columns
274 .iter()
275 .any(|left_col| left_col.name == col.name)
276 {
277 result.add_column(DataColumn {
278 name: col.name.clone(),
279 data_type: col.data_type.clone(),
280 nullable: col.nullable,
281 unique_values: col.unique_values,
282 null_count: col.null_count,
283 metadata: col.metadata.clone(),
284 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
287 } else {
288 result.add_column(DataColumn {
290 name: format!("{}_right", col.name),
291 data_type: col.data_type.clone(),
292 nullable: col.nullable,
293 unique_values: col.unique_values,
294 null_count: col.null_count,
295 metadata: col.metadata.clone(),
296 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
299 }
300 }
301
302 debug!(
303 "Joined table will have {} columns: {:?}",
304 result.column_count(),
305 result.column_names()
306 );
307
308 let mut match_count = 0;
310 for probe_row in &probe_table.rows {
311 let probe_key = &probe_row.values[probe_col_idx];
312
313 if let Some(matching_indices) = hash_index.get(probe_key) {
314 for &build_idx in matching_indices {
315 let build_row = &build_table.rows[build_idx];
316
317 let mut joined_row = DataRow { values: Vec::new() };
319
320 if build_is_left {
321 joined_row.values.extend_from_slice(&build_row.values);
323 joined_row.values.extend_from_slice(&probe_row.values);
324 } else {
325 joined_row.values.extend_from_slice(&probe_row.values);
327 joined_row.values.extend_from_slice(&build_row.values);
328 }
329
330 result.add_row(joined_row);
331 match_count += 1;
332 }
333 }
334 }
335
336 let qualified_cols: Vec<String> = result
338 .columns
339 .iter()
340 .filter_map(|c| c.qualified_name.clone())
341 .collect();
342
343 info!(
344 "INNER JOIN complete: {} matches found in {:?}. Result has {} columns ({} qualified: {:?})",
345 match_count,
346 start.elapsed(),
347 result.columns.len(),
348 qualified_cols.len(),
349 qualified_cols
350 );
351
352 Ok(result)
353 }
354
355 fn hash_join_left(
357 &self,
358 left_table: Arc<DataTable>,
359 right_table: Arc<DataTable>,
360 left_col_idx: usize,
361 right_col_idx: usize,
362 _left_col_name: &str,
363 _right_col_name: &str,
364 ) -> Result<DataTable> {
365 let start = std::time::Instant::now();
366
367 debug!(
368 "Building hash index on right table ({} rows)",
369 right_table.row_count()
370 );
371
372 let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
374 for (row_idx, row) in right_table.rows.iter().enumerate() {
375 let key = row.values[right_col_idx].clone();
376 hash_index.entry(key).or_default().push(row_idx);
377 }
378
379 let mut result = DataTable::new("joined");
381
382 for col in &left_table.columns {
384 result.add_column(DataColumn {
385 name: col.name.clone(),
386 data_type: col.data_type.clone(),
387 nullable: col.nullable,
388 unique_values: col.unique_values,
389 null_count: col.null_count,
390 metadata: col.metadata.clone(),
391 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
394 }
395
396 for col in &right_table.columns {
398 if !left_table
400 .columns
401 .iter()
402 .any(|left_col| left_col.name == col.name)
403 {
404 result.add_column(DataColumn {
405 name: col.name.clone(),
406 data_type: col.data_type.clone(),
407 nullable: true, unique_values: col.unique_values,
409 null_count: col.null_count,
410 metadata: col.metadata.clone(),
411 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
414 } else {
415 result.add_column(DataColumn {
417 name: format!("{}_right", col.name),
418 data_type: col.data_type.clone(),
419 nullable: true, unique_values: col.unique_values,
421 null_count: col.null_count,
422 metadata: col.metadata.clone(),
423 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
426 }
427 }
428
429 debug!(
430 "LEFT JOIN table will have {} columns: {:?}",
431 result.column_count(),
432 result.column_names()
433 );
434
435 let mut match_count = 0;
437 let mut null_count = 0;
438
439 for left_row in &left_table.rows {
440 let left_key = &left_row.values[left_col_idx];
441
442 if let Some(matching_indices) = hash_index.get(left_key) {
443 for &right_idx in matching_indices {
445 let right_row = &right_table.rows[right_idx];
446
447 let mut joined_row = DataRow { values: Vec::new() };
448 joined_row.values.extend_from_slice(&left_row.values);
449 joined_row.values.extend_from_slice(&right_row.values);
450
451 result.add_row(joined_row);
452 match_count += 1;
453 }
454 } else {
455 let mut joined_row = DataRow { values: Vec::new() };
457 joined_row.values.extend_from_slice(&left_row.values);
458
459 for _ in 0..right_table.column_count() {
461 joined_row.values.push(DataValue::Null);
462 }
463
464 result.add_row(joined_row);
465 null_count += 1;
466 }
467 }
468
469 let qualified_cols: Vec<String> = result
471 .columns
472 .iter()
473 .filter_map(|c| c.qualified_name.clone())
474 .collect();
475
476 info!(
477 "LEFT JOIN complete: {} matches, {} nulls in {:?}. Result has {} columns ({} qualified: {:?})",
478 match_count,
479 null_count,
480 start.elapsed(),
481 result.columns.len(),
482 qualified_cols.len(),
483 qualified_cols
484 );
485
486 Ok(result)
487 }
488
489 fn cross_join(
491 &self,
492 left_table: Arc<DataTable>,
493 right_table: Arc<DataTable>,
494 ) -> Result<DataTable> {
495 let start = std::time::Instant::now();
496
497 let result_rows = left_table.row_count() * right_table.row_count();
499 if result_rows > 1_000_000 {
500 return Err(anyhow!(
501 "CROSS JOIN would produce {} rows, which exceeds the safety limit",
502 result_rows
503 ));
504 }
505
506 let mut result = DataTable::new("joined");
508
509 for col in &left_table.columns {
511 result.add_column(col.clone());
512 }
513 for col in &right_table.columns {
514 result.add_column(col.clone());
515 }
516
517 for left_row in &left_table.rows {
519 for right_row in &right_table.rows {
520 let mut joined_row = DataRow { values: Vec::new() };
521 joined_row.values.extend_from_slice(&left_row.values);
522 joined_row.values.extend_from_slice(&right_row.values);
523 result.add_row(joined_row);
524 }
525 }
526
527 info!(
528 "CROSS JOIN complete: {} rows in {:?}",
529 result.row_count(),
530 start.elapsed()
531 );
532
533 Ok(result)
534 }
535
536 fn qualify_column_name(
538 &self,
539 col_name: &str,
540 table_side: &str,
541 left_join_col: &str,
542 right_join_col: &str,
543 ) -> String {
544 let base_name = if let Some(dot_pos) = col_name.rfind('.') {
546 &col_name[dot_pos + 1..]
547 } else {
548 col_name
549 };
550
551 let left_base = if let Some(dot_pos) = left_join_col.rfind('.') {
552 &left_join_col[dot_pos + 1..]
553 } else {
554 left_join_col
555 };
556
557 let right_base = if let Some(dot_pos) = right_join_col.rfind('.') {
558 &right_join_col[dot_pos + 1..]
559 } else {
560 right_join_col
561 };
562
563 if base_name == left_base || base_name == right_base {
565 format!("{}_{}", table_side, base_name)
566 } else {
567 col_name.to_string()
568 }
569 }
570
571 fn reverse_operator(&self, op: &JoinOperator) -> JoinOperator {
573 match op {
574 JoinOperator::Equal => JoinOperator::Equal,
575 JoinOperator::NotEqual => JoinOperator::NotEqual,
576 JoinOperator::LessThan => JoinOperator::GreaterThan,
577 JoinOperator::GreaterThan => JoinOperator::LessThan,
578 JoinOperator::LessThanOrEqual => JoinOperator::GreaterThanOrEqual,
579 JoinOperator::GreaterThanOrEqual => JoinOperator::LessThanOrEqual,
580 }
581 }
582
583 fn compare_values(&self, left: &DataValue, right: &DataValue, op: &JoinOperator) -> bool {
585 match op {
586 JoinOperator::Equal => left == right,
587 JoinOperator::NotEqual => left != right,
588 JoinOperator::LessThan => left < right,
589 JoinOperator::GreaterThan => left > right,
590 JoinOperator::LessThanOrEqual => left <= right,
591 JoinOperator::GreaterThanOrEqual => left >= right,
592 }
593 }
594
595 fn nested_loop_join_inner(
597 &self,
598 left_table: Arc<DataTable>,
599 right_table: Arc<DataTable>,
600 left_col_idx: usize,
601 right_col_idx: usize,
602 operator: &JoinOperator,
603 ) -> Result<DataTable> {
604 let start = std::time::Instant::now();
605
606 info!(
607 "Executing nested loop INNER JOIN with {:?} operator: {} x {} rows",
608 operator,
609 left_table.row_count(),
610 right_table.row_count()
611 );
612
613 let mut result = DataTable::new("joined");
615
616 for col in &left_table.columns {
618 result.add_column(DataColumn {
619 name: col.name.clone(),
620 data_type: col.data_type.clone(),
621 nullable: col.nullable,
622 unique_values: col.unique_values,
623 null_count: col.null_count,
624 metadata: col.metadata.clone(),
625 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
628 }
629
630 for col in &right_table.columns {
632 if !left_table
633 .columns
634 .iter()
635 .any(|left_col| left_col.name == col.name)
636 {
637 result.add_column(DataColumn {
638 name: col.name.clone(),
639 data_type: col.data_type.clone(),
640 nullable: col.nullable,
641 unique_values: col.unique_values,
642 null_count: col.null_count,
643 metadata: col.metadata.clone(),
644 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
647 } else {
648 result.add_column(DataColumn {
649 name: format!("{}_right", col.name),
650 data_type: col.data_type.clone(),
651 nullable: col.nullable,
652 unique_values: col.unique_values,
653 null_count: col.null_count,
654 metadata: col.metadata.clone(),
655 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
658 }
659 }
660
661 let mut match_count = 0;
663 for left_row in &left_table.rows {
664 let left_value = &left_row.values[left_col_idx];
665
666 for right_row in &right_table.rows {
667 let right_value = &right_row.values[right_col_idx];
668
669 if self.compare_values(left_value, right_value, operator) {
670 let mut joined_row = DataRow { values: Vec::new() };
671 joined_row.values.extend_from_slice(&left_row.values);
672 joined_row.values.extend_from_slice(&right_row.values);
673 result.add_row(joined_row);
674 match_count += 1;
675 }
676 }
677 }
678
679 info!(
680 "Nested loop INNER JOIN complete: {} matches found in {:?}",
681 match_count,
682 start.elapsed()
683 );
684
685 Ok(result)
686 }
687
688 fn nested_loop_join_left(
690 &self,
691 left_table: Arc<DataTable>,
692 right_table: Arc<DataTable>,
693 left_col_idx: usize,
694 right_col_idx: usize,
695 operator: &JoinOperator,
696 ) -> Result<DataTable> {
697 let start = std::time::Instant::now();
698
699 info!(
700 "Executing nested loop LEFT JOIN with {:?} operator: {} x {} rows",
701 operator,
702 left_table.row_count(),
703 right_table.row_count()
704 );
705
706 let mut result = DataTable::new("joined");
708
709 for col in &left_table.columns {
711 result.add_column(DataColumn {
712 name: col.name.clone(),
713 data_type: col.data_type.clone(),
714 nullable: col.nullable,
715 unique_values: col.unique_values,
716 null_count: col.null_count,
717 metadata: col.metadata.clone(),
718 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
721 }
722
723 for col in &right_table.columns {
725 if !left_table
726 .columns
727 .iter()
728 .any(|left_col| left_col.name == col.name)
729 {
730 result.add_column(DataColumn {
731 name: col.name.clone(),
732 data_type: col.data_type.clone(),
733 nullable: true, unique_values: col.unique_values,
735 null_count: col.null_count,
736 metadata: col.metadata.clone(),
737 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
740 } else {
741 result.add_column(DataColumn {
742 name: format!("{}_right", col.name),
743 data_type: col.data_type.clone(),
744 nullable: true, unique_values: col.unique_values,
746 null_count: col.null_count,
747 metadata: col.metadata.clone(),
748 qualified_name: col.qualified_name.clone(), source_table: col.source_table.clone(), });
751 }
752 }
753
754 let mut match_count = 0;
756 let mut null_count = 0;
757
758 for left_row in &left_table.rows {
759 let left_value = &left_row.values[left_col_idx];
760 let mut found_match = false;
761
762 for right_row in &right_table.rows {
763 let right_value = &right_row.values[right_col_idx];
764
765 if self.compare_values(left_value, right_value, operator) {
766 let mut joined_row = DataRow { values: Vec::new() };
767 joined_row.values.extend_from_slice(&left_row.values);
768 joined_row.values.extend_from_slice(&right_row.values);
769 result.add_row(joined_row);
770 match_count += 1;
771 found_match = true;
772 }
773 }
774
775 if !found_match {
777 let mut joined_row = DataRow { values: Vec::new() };
778 joined_row.values.extend_from_slice(&left_row.values);
779 for _ in 0..right_table.column_count() {
780 joined_row.values.push(DataValue::Null);
781 }
782 result.add_row(joined_row);
783 null_count += 1;
784 }
785 }
786
787 info!(
788 "Nested loop LEFT JOIN complete: {} matches, {} nulls in {:?}",
789 match_count,
790 null_count,
791 start.elapsed()
792 );
793
794 Ok(result)
795 }
796}