1use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tracing::{debug, info};
7
8use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
9use crate::sql::parser::ast::{JoinClause, JoinOperator, JoinType};
10
/// Executes SQL joins (INNER, LEFT, RIGHT, CROSS) between two in-memory [`DataTable`]s.
#[derive(Debug, Clone)]
pub struct HashJoinExecutor {
    /// When `true`, join-column names are matched against table column names ignoring case.
    case_insensitive: bool,
}
15
16impl HashJoinExecutor {
17 pub fn new(case_insensitive: bool) -> Self {
18 Self { case_insensitive }
19 }
20
21 pub fn execute_join(
23 &self,
24 left_table: Arc<DataTable>,
25 join_clause: &JoinClause,
26 right_table: Arc<DataTable>,
27 ) -> Result<DataTable> {
28 info!(
29 "Executing {:?} JOIN: {} rows x {} rows",
30 join_clause.join_type,
31 left_table.row_count(),
32 right_table.row_count()
33 );
34
35 let (left_col_name, right_col_name) = self.parse_join_columns(join_clause)?;
37
38 let (left_col_idx, right_col_idx) =
42 self.resolve_join_columns(&left_table, &right_table, &left_col_name, &right_col_name)?;
43
44 let use_hash_join = join_clause.condition.operator == JoinOperator::Equal;
46
47 match join_clause.join_type {
49 JoinType::Inner => {
50 if use_hash_join {
51 self.hash_join_inner(
52 left_table,
53 right_table,
54 left_col_idx,
55 right_col_idx,
56 &left_col_name,
57 &right_col_name,
58 )
59 } else {
60 self.nested_loop_join_inner(
61 left_table,
62 right_table,
63 left_col_idx,
64 right_col_idx,
65 &join_clause.condition.operator,
66 )
67 }
68 }
69 JoinType::Left => {
70 if use_hash_join {
71 self.hash_join_left(
72 left_table,
73 right_table,
74 left_col_idx,
75 right_col_idx,
76 &left_col_name,
77 &right_col_name,
78 )
79 } else {
80 self.nested_loop_join_left(
81 left_table,
82 right_table,
83 left_col_idx,
84 right_col_idx,
85 &join_clause.condition.operator,
86 )
87 }
88 }
89 JoinType::Right => {
90 if use_hash_join {
91 self.hash_join_left(
93 right_table,
94 left_table,
95 right_col_idx,
96 left_col_idx,
97 &right_col_name,
98 &left_col_name,
99 )
100 } else {
101 self.nested_loop_join_left(
103 right_table,
104 left_table,
105 right_col_idx,
106 left_col_idx,
107 &self.reverse_operator(&join_clause.condition.operator),
108 )
109 }
110 }
111 JoinType::Cross => self.cross_join(left_table, right_table),
112 JoinType::Full => {
113 return Err(anyhow!("FULL OUTER JOIN not yet implemented"));
114 }
115 }
116 }
117
118 fn parse_join_columns(&self, join_clause: &JoinClause) -> Result<(String, String)> {
120 Ok((
121 join_clause.condition.left_column.clone(),
122 join_clause.condition.right_column.clone(),
123 ))
124 }
125
126 fn resolve_join_columns(
128 &self,
129 left_table: &DataTable,
130 right_table: &DataTable,
131 left_col_name: &str,
132 right_col_name: &str,
133 ) -> Result<(usize, usize)> {
134 let left_col_idx = if let Ok(idx) = self.find_column_index(left_table, left_col_name) {
136 idx
137 } else if let Ok(idx) = self.find_column_index(right_table, left_col_name) {
138 return Err(anyhow!(
141 "Column '{}' found in right table but specified as left operand. \
142 Please rewrite the condition with columns in correct positions.",
143 left_col_name
144 ));
145 } else {
146 return Err(anyhow!(
147 "Column '{}' not found in either table",
148 left_col_name
149 ));
150 };
151
152 let right_col_idx = if let Ok(idx) = self.find_column_index(right_table, right_col_name) {
154 idx
155 } else if let Ok(idx) = self.find_column_index(left_table, right_col_name) {
156 return Err(anyhow!(
159 "Column '{}' found in left table but specified as right operand. \
160 Please rewrite the condition with columns in correct positions.",
161 right_col_name
162 ));
163 } else {
164 return Err(anyhow!(
165 "Column '{}' not found in either table",
166 right_col_name
167 ));
168 };
169
170 Ok((left_col_idx, right_col_idx))
171 }
172
173 fn find_column_index(&self, table: &DataTable, col_name: &str) -> Result<usize> {
175 let col_name = if let Some(dot_pos) = col_name.rfind('.') {
177 &col_name[dot_pos + 1..]
178 } else {
179 col_name
180 };
181
182 debug!(
183 "Looking for column '{}' in table with columns: {:?}",
184 col_name,
185 table.column_names()
186 );
187
188 table
189 .columns
190 .iter()
191 .position(|col| {
192 if self.case_insensitive {
193 col.name.to_lowercase() == col_name.to_lowercase()
194 } else {
195 col.name == col_name
196 }
197 })
198 .ok_or_else(|| anyhow!("Column '{}' not found in table", col_name))
199 }
200
201 fn hash_join_inner(
203 &self,
204 left_table: Arc<DataTable>,
205 right_table: Arc<DataTable>,
206 left_col_idx: usize,
207 right_col_idx: usize,
208 _left_col_name: &str,
209 _right_col_name: &str,
210 ) -> Result<DataTable> {
211 let start = std::time::Instant::now();
212
213 let (build_table, probe_table, build_col_idx, probe_col_idx, build_is_left) =
215 if left_table.row_count() <= right_table.row_count() {
216 (
217 left_table.clone(),
218 right_table.clone(),
219 left_col_idx,
220 right_col_idx,
221 true,
222 )
223 } else {
224 (
225 right_table.clone(),
226 left_table.clone(),
227 right_col_idx,
228 left_col_idx,
229 false,
230 )
231 };
232
233 debug!(
234 "Building hash index on {} table ({} rows)",
235 if build_is_left { "left" } else { "right" },
236 build_table.row_count()
237 );
238
239 let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
241 for (row_idx, row) in build_table.rows.iter().enumerate() {
242 let key = row.values[build_col_idx].clone();
243 hash_index.entry(key).or_default().push(row_idx);
244 }
245
246 debug!(
247 "Hash index built with {} unique keys in {:?}",
248 hash_index.len(),
249 start.elapsed()
250 );
251
252 let mut result = DataTable::new("joined");
254
255 for col in &left_table.columns {
257 result.add_column(DataColumn {
258 name: col.name.clone(),
259 data_type: col.data_type.clone(),
260 nullable: col.nullable,
261 unique_values: col.unique_values,
262 null_count: col.null_count,
263 metadata: col.metadata.clone(),
264 });
265 }
266
267 for col in &right_table.columns {
269 if !left_table
271 .columns
272 .iter()
273 .any(|left_col| left_col.name == col.name)
274 {
275 result.add_column(DataColumn {
276 name: col.name.clone(),
277 data_type: col.data_type.clone(),
278 nullable: col.nullable,
279 unique_values: col.unique_values,
280 null_count: col.null_count,
281 metadata: col.metadata.clone(),
282 });
283 } else {
284 result.add_column(DataColumn {
286 name: format!("{}_right", col.name),
287 data_type: col.data_type.clone(),
288 nullable: col.nullable,
289 unique_values: col.unique_values,
290 null_count: col.null_count,
291 metadata: col.metadata.clone(),
292 });
293 }
294 }
295
296 debug!(
297 "Joined table will have {} columns: {:?}",
298 result.column_count(),
299 result.column_names()
300 );
301
302 let mut match_count = 0;
304 for probe_row in &probe_table.rows {
305 let probe_key = &probe_row.values[probe_col_idx];
306
307 if let Some(matching_indices) = hash_index.get(probe_key) {
308 for &build_idx in matching_indices {
309 let build_row = &build_table.rows[build_idx];
310
311 let mut joined_row = DataRow { values: Vec::new() };
313
314 if build_is_left {
315 joined_row.values.extend_from_slice(&build_row.values);
317 joined_row.values.extend_from_slice(&probe_row.values);
318 } else {
319 joined_row.values.extend_from_slice(&probe_row.values);
321 joined_row.values.extend_from_slice(&build_row.values);
322 }
323
324 result.add_row(joined_row);
325 match_count += 1;
326 }
327 }
328 }
329
330 info!(
331 "INNER JOIN complete: {} matches found in {:?}",
332 match_count,
333 start.elapsed()
334 );
335
336 Ok(result)
337 }
338
339 fn hash_join_left(
341 &self,
342 left_table: Arc<DataTable>,
343 right_table: Arc<DataTable>,
344 left_col_idx: usize,
345 right_col_idx: usize,
346 _left_col_name: &str,
347 _right_col_name: &str,
348 ) -> Result<DataTable> {
349 let start = std::time::Instant::now();
350
351 debug!(
352 "Building hash index on right table ({} rows)",
353 right_table.row_count()
354 );
355
356 let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
358 for (row_idx, row) in right_table.rows.iter().enumerate() {
359 let key = row.values[right_col_idx].clone();
360 hash_index.entry(key).or_default().push(row_idx);
361 }
362
363 let mut result = DataTable::new("joined");
365
366 for col in &left_table.columns {
368 result.add_column(DataColumn {
369 name: col.name.clone(),
370 data_type: col.data_type.clone(),
371 nullable: col.nullable,
372 unique_values: col.unique_values,
373 null_count: col.null_count,
374 metadata: col.metadata.clone(),
375 });
376 }
377
378 for col in &right_table.columns {
380 if !left_table
382 .columns
383 .iter()
384 .any(|left_col| left_col.name == col.name)
385 {
386 result.add_column(DataColumn {
387 name: col.name.clone(),
388 data_type: col.data_type.clone(),
389 nullable: true, unique_values: col.unique_values,
391 null_count: col.null_count,
392 metadata: col.metadata.clone(),
393 });
394 } else {
395 result.add_column(DataColumn {
397 name: format!("{}_right", col.name),
398 data_type: col.data_type.clone(),
399 nullable: true, unique_values: col.unique_values,
401 null_count: col.null_count,
402 metadata: col.metadata.clone(),
403 });
404 }
405 }
406
407 debug!(
408 "LEFT JOIN table will have {} columns: {:?}",
409 result.column_count(),
410 result.column_names()
411 );
412
413 let mut match_count = 0;
415 let mut null_count = 0;
416
417 for left_row in &left_table.rows {
418 let left_key = &left_row.values[left_col_idx];
419
420 if let Some(matching_indices) = hash_index.get(left_key) {
421 for &right_idx in matching_indices {
423 let right_row = &right_table.rows[right_idx];
424
425 let mut joined_row = DataRow { values: Vec::new() };
426 joined_row.values.extend_from_slice(&left_row.values);
427 joined_row.values.extend_from_slice(&right_row.values);
428
429 result.add_row(joined_row);
430 match_count += 1;
431 }
432 } else {
433 let mut joined_row = DataRow { values: Vec::new() };
435 joined_row.values.extend_from_slice(&left_row.values);
436
437 for _ in 0..right_table.column_count() {
439 joined_row.values.push(DataValue::Null);
440 }
441
442 result.add_row(joined_row);
443 null_count += 1;
444 }
445 }
446
447 info!(
448 "LEFT JOIN complete: {} matches, {} nulls in {:?}",
449 match_count,
450 null_count,
451 start.elapsed()
452 );
453
454 Ok(result)
455 }
456
457 fn cross_join(
459 &self,
460 left_table: Arc<DataTable>,
461 right_table: Arc<DataTable>,
462 ) -> Result<DataTable> {
463 let start = std::time::Instant::now();
464
465 let result_rows = left_table.row_count() * right_table.row_count();
467 if result_rows > 1_000_000 {
468 return Err(anyhow!(
469 "CROSS JOIN would produce {} rows, which exceeds the safety limit",
470 result_rows
471 ));
472 }
473
474 let mut result = DataTable::new("joined");
476
477 for col in &left_table.columns {
479 result.add_column(col.clone());
480 }
481 for col in &right_table.columns {
482 result.add_column(col.clone());
483 }
484
485 for left_row in &left_table.rows {
487 for right_row in &right_table.rows {
488 let mut joined_row = DataRow { values: Vec::new() };
489 joined_row.values.extend_from_slice(&left_row.values);
490 joined_row.values.extend_from_slice(&right_row.values);
491 result.add_row(joined_row);
492 }
493 }
494
495 info!(
496 "CROSS JOIN complete: {} rows in {:?}",
497 result.row_count(),
498 start.elapsed()
499 );
500
501 Ok(result)
502 }
503
504 fn qualify_column_name(
506 &self,
507 col_name: &str,
508 table_side: &str,
509 left_join_col: &str,
510 right_join_col: &str,
511 ) -> String {
512 let base_name = if let Some(dot_pos) = col_name.rfind('.') {
514 &col_name[dot_pos + 1..]
515 } else {
516 col_name
517 };
518
519 let left_base = if let Some(dot_pos) = left_join_col.rfind('.') {
520 &left_join_col[dot_pos + 1..]
521 } else {
522 left_join_col
523 };
524
525 let right_base = if let Some(dot_pos) = right_join_col.rfind('.') {
526 &right_join_col[dot_pos + 1..]
527 } else {
528 right_join_col
529 };
530
531 if base_name == left_base || base_name == right_base {
533 format!("{}_{}", table_side, base_name)
534 } else {
535 col_name.to_string()
536 }
537 }
538
539 fn reverse_operator(&self, op: &JoinOperator) -> JoinOperator {
541 match op {
542 JoinOperator::Equal => JoinOperator::Equal,
543 JoinOperator::NotEqual => JoinOperator::NotEqual,
544 JoinOperator::LessThan => JoinOperator::GreaterThan,
545 JoinOperator::GreaterThan => JoinOperator::LessThan,
546 JoinOperator::LessThanOrEqual => JoinOperator::GreaterThanOrEqual,
547 JoinOperator::GreaterThanOrEqual => JoinOperator::LessThanOrEqual,
548 }
549 }
550
551 fn compare_values(&self, left: &DataValue, right: &DataValue, op: &JoinOperator) -> bool {
553 match op {
554 JoinOperator::Equal => left == right,
555 JoinOperator::NotEqual => left != right,
556 JoinOperator::LessThan => left < right,
557 JoinOperator::GreaterThan => left > right,
558 JoinOperator::LessThanOrEqual => left <= right,
559 JoinOperator::GreaterThanOrEqual => left >= right,
560 }
561 }
562
563 fn nested_loop_join_inner(
565 &self,
566 left_table: Arc<DataTable>,
567 right_table: Arc<DataTable>,
568 left_col_idx: usize,
569 right_col_idx: usize,
570 operator: &JoinOperator,
571 ) -> Result<DataTable> {
572 let start = std::time::Instant::now();
573
574 info!(
575 "Executing nested loop INNER JOIN with {:?} operator: {} x {} rows",
576 operator,
577 left_table.row_count(),
578 right_table.row_count()
579 );
580
581 let mut result = DataTable::new("joined");
583
584 for col in &left_table.columns {
586 result.add_column(DataColumn {
587 name: col.name.clone(),
588 data_type: col.data_type.clone(),
589 nullable: col.nullable,
590 unique_values: col.unique_values,
591 null_count: col.null_count,
592 metadata: col.metadata.clone(),
593 });
594 }
595
596 for col in &right_table.columns {
598 if !left_table
599 .columns
600 .iter()
601 .any(|left_col| left_col.name == col.name)
602 {
603 result.add_column(DataColumn {
604 name: col.name.clone(),
605 data_type: col.data_type.clone(),
606 nullable: col.nullable,
607 unique_values: col.unique_values,
608 null_count: col.null_count,
609 metadata: col.metadata.clone(),
610 });
611 } else {
612 result.add_column(DataColumn {
613 name: format!("{}_right", col.name),
614 data_type: col.data_type.clone(),
615 nullable: col.nullable,
616 unique_values: col.unique_values,
617 null_count: col.null_count,
618 metadata: col.metadata.clone(),
619 });
620 }
621 }
622
623 let mut match_count = 0;
625 for left_row in &left_table.rows {
626 let left_value = &left_row.values[left_col_idx];
627
628 for right_row in &right_table.rows {
629 let right_value = &right_row.values[right_col_idx];
630
631 if self.compare_values(left_value, right_value, operator) {
632 let mut joined_row = DataRow { values: Vec::new() };
633 joined_row.values.extend_from_slice(&left_row.values);
634 joined_row.values.extend_from_slice(&right_row.values);
635 result.add_row(joined_row);
636 match_count += 1;
637 }
638 }
639 }
640
641 info!(
642 "Nested loop INNER JOIN complete: {} matches found in {:?}",
643 match_count,
644 start.elapsed()
645 );
646
647 Ok(result)
648 }
649
650 fn nested_loop_join_left(
652 &self,
653 left_table: Arc<DataTable>,
654 right_table: Arc<DataTable>,
655 left_col_idx: usize,
656 right_col_idx: usize,
657 operator: &JoinOperator,
658 ) -> Result<DataTable> {
659 let start = std::time::Instant::now();
660
661 info!(
662 "Executing nested loop LEFT JOIN with {:?} operator: {} x {} rows",
663 operator,
664 left_table.row_count(),
665 right_table.row_count()
666 );
667
668 let mut result = DataTable::new("joined");
670
671 for col in &left_table.columns {
673 result.add_column(DataColumn {
674 name: col.name.clone(),
675 data_type: col.data_type.clone(),
676 nullable: col.nullable,
677 unique_values: col.unique_values,
678 null_count: col.null_count,
679 metadata: col.metadata.clone(),
680 });
681 }
682
683 for col in &right_table.columns {
685 if !left_table
686 .columns
687 .iter()
688 .any(|left_col| left_col.name == col.name)
689 {
690 result.add_column(DataColumn {
691 name: col.name.clone(),
692 data_type: col.data_type.clone(),
693 nullable: true, unique_values: col.unique_values,
695 null_count: col.null_count,
696 metadata: col.metadata.clone(),
697 });
698 } else {
699 result.add_column(DataColumn {
700 name: format!("{}_right", col.name),
701 data_type: col.data_type.clone(),
702 nullable: true, unique_values: col.unique_values,
704 null_count: col.null_count,
705 metadata: col.metadata.clone(),
706 });
707 }
708 }
709
710 let mut match_count = 0;
712 let mut null_count = 0;
713
714 for left_row in &left_table.rows {
715 let left_value = &left_row.values[left_col_idx];
716 let mut found_match = false;
717
718 for right_row in &right_table.rows {
719 let right_value = &right_row.values[right_col_idx];
720
721 if self.compare_values(left_value, right_value, operator) {
722 let mut joined_row = DataRow { values: Vec::new() };
723 joined_row.values.extend_from_slice(&left_row.values);
724 joined_row.values.extend_from_slice(&right_row.values);
725 result.add_row(joined_row);
726 match_count += 1;
727 found_match = true;
728 }
729 }
730
731 if !found_match {
733 let mut joined_row = DataRow { values: Vec::new() };
734 joined_row.values.extend_from_slice(&left_row.values);
735 for _ in 0..right_table.column_count() {
736 joined_row.values.push(DataValue::Null);
737 }
738 result.add_row(joined_row);
739 null_count += 1;
740 }
741 }
742
743 info!(
744 "Nested loop LEFT JOIN complete: {} matches, {} nulls in {:?}",
745 match_count,
746 null_count,
747 start.elapsed()
748 );
749
750 Ok(result)
751 }
752}