// sql_cli/data/hash_join.rs

//! Hash join implementation for efficient JOIN operations
2
3use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tracing::{debug, info};
7
8use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
9use crate::sql::parser::ast::{JoinClause, JoinOperator, JoinType};
10
/// Hash join executor for efficient JOIN operations.
///
/// The only configuration is `case_insensitive`, which controls whether join
/// column names are matched against table columns without regard to case.
#[derive(Debug, Clone)]
pub struct HashJoinExecutor {
    /// When `true`, column-name lookups compare lowercased names.
    case_insensitive: bool,
}
15
16impl HashJoinExecutor {
17    pub fn new(case_insensitive: bool) -> Self {
18        Self { case_insensitive }
19    }
20
21    /// Execute a single join operation
22    pub fn execute_join(
23        &self,
24        left_table: Arc<DataTable>,
25        join_clause: &JoinClause,
26        right_table: Arc<DataTable>,
27    ) -> Result<DataTable> {
28        info!(
29            "Executing {:?} JOIN: {} rows x {} rows",
30            join_clause.join_type,
31            left_table.row_count(),
32            right_table.row_count()
33        );
34
35        // Extract column references from the join condition
36        let (left_col_name, right_col_name) = self.parse_join_columns(join_clause)?;
37
38        // For join conditions, we need to be smart about which table each column belongs to
39        // The left column could reference either table, same with right column
40        // Try to resolve based on table prefixes or column existence
41        let (left_col_idx, right_col_idx) =
42            self.resolve_join_columns(&left_table, &right_table, &left_col_name, &right_col_name)?;
43
44        // Choose join algorithm based on operator
45        let use_hash_join = join_clause.condition.operator == JoinOperator::Equal;
46
47        // Perform the appropriate join based on type and operator
48        match join_clause.join_type {
49            JoinType::Inner => {
50                if use_hash_join {
51                    self.hash_join_inner(
52                        left_table,
53                        right_table,
54                        left_col_idx,
55                        right_col_idx,
56                        &left_col_name,
57                        &right_col_name,
58                    )
59                } else {
60                    self.nested_loop_join_inner(
61                        left_table,
62                        right_table,
63                        left_col_idx,
64                        right_col_idx,
65                        &join_clause.condition.operator,
66                    )
67                }
68            }
69            JoinType::Left => {
70                if use_hash_join {
71                    self.hash_join_left(
72                        left_table,
73                        right_table,
74                        left_col_idx,
75                        right_col_idx,
76                        &left_col_name,
77                        &right_col_name,
78                    )
79                } else {
80                    self.nested_loop_join_left(
81                        left_table,
82                        right_table,
83                        left_col_idx,
84                        right_col_idx,
85                        &join_clause.condition.operator,
86                    )
87                }
88            }
89            JoinType::Right => {
90                if use_hash_join {
91                    // Right join is just a left join with tables swapped
92                    self.hash_join_left(
93                        right_table,
94                        left_table,
95                        right_col_idx,
96                        left_col_idx,
97                        &right_col_name,
98                        &left_col_name,
99                    )
100                } else {
101                    // Right join is just a left join with tables swapped
102                    self.nested_loop_join_left(
103                        right_table,
104                        left_table,
105                        right_col_idx,
106                        left_col_idx,
107                        &self.reverse_operator(&join_clause.condition.operator),
108                    )
109                }
110            }
111            JoinType::Cross => self.cross_join(left_table, right_table),
112            JoinType::Full => {
113                return Err(anyhow!("FULL OUTER JOIN not yet implemented"));
114            }
115        }
116    }
117
118    /// Parse join columns from the join condition
119    fn parse_join_columns(&self, join_clause: &JoinClause) -> Result<(String, String)> {
120        Ok((
121            join_clause.condition.left_column.clone(),
122            join_clause.condition.right_column.clone(),
123        ))
124    }
125
126    /// Resolve which table each column belongs to in a join condition
127    fn resolve_join_columns(
128        &self,
129        left_table: &DataTable,
130        right_table: &DataTable,
131        left_col_name: &str,
132        right_col_name: &str,
133    ) -> Result<(usize, usize)> {
134        // Try to find the left column in left table, then right table
135        let left_col_idx = if let Ok(idx) = self.find_column_index(left_table, left_col_name) {
136            idx
137        } else if let Ok(idx) = self.find_column_index(right_table, left_col_name) {
138            // The "left" column in the condition is actually from the right table
139            // This means we need to swap the comparison
140            return Err(anyhow!(
141                "Column '{}' found in right table but specified as left operand. \
142                Please rewrite the condition with columns in correct positions.",
143                left_col_name
144            ));
145        } else {
146            return Err(anyhow!(
147                "Column '{}' not found in either table",
148                left_col_name
149            ));
150        };
151
152        // Try to find the right column in right table, then left table
153        let right_col_idx = if let Ok(idx) = self.find_column_index(right_table, right_col_name) {
154            idx
155        } else if let Ok(idx) = self.find_column_index(left_table, right_col_name) {
156            // The "right" column in the condition is actually from the left table
157            // This means we need to swap the comparison
158            return Err(anyhow!(
159                "Column '{}' found in left table but specified as right operand. \
160                Please rewrite the condition with columns in correct positions.",
161                right_col_name
162            ));
163        } else {
164            return Err(anyhow!(
165                "Column '{}' not found in either table",
166                right_col_name
167            ));
168        };
169
170        Ok((left_col_idx, right_col_idx))
171    }
172
173    /// Find column index in a table
174    fn find_column_index(&self, table: &DataTable, col_name: &str) -> Result<usize> {
175        // Handle table-qualified column names (e.g., "t1.id")
176        let col_name = if let Some(dot_pos) = col_name.rfind('.') {
177            &col_name[dot_pos + 1..]
178        } else {
179            col_name
180        };
181
182        debug!(
183            "Looking for column '{}' in table with columns: {:?}",
184            col_name,
185            table.column_names()
186        );
187
188        table
189            .columns
190            .iter()
191            .position(|col| {
192                if self.case_insensitive {
193                    col.name.to_lowercase() == col_name.to_lowercase()
194                } else {
195                    col.name == col_name
196                }
197            })
198            .ok_or_else(|| anyhow!("Column '{}' not found in table", col_name))
199    }
200
201    /// Hash join implementation for INNER JOIN
202    fn hash_join_inner(
203        &self,
204        left_table: Arc<DataTable>,
205        right_table: Arc<DataTable>,
206        left_col_idx: usize,
207        right_col_idx: usize,
208        _left_col_name: &str,
209        _right_col_name: &str,
210    ) -> Result<DataTable> {
211        let start = std::time::Instant::now();
212
213        // Determine which table to use for building the hash index (prefer smaller)
214        let (build_table, probe_table, build_col_idx, probe_col_idx, build_is_left) =
215            if left_table.row_count() <= right_table.row_count() {
216                (
217                    left_table.clone(),
218                    right_table.clone(),
219                    left_col_idx,
220                    right_col_idx,
221                    true,
222                )
223            } else {
224                (
225                    right_table.clone(),
226                    left_table.clone(),
227                    right_col_idx,
228                    left_col_idx,
229                    false,
230                )
231            };
232
233        debug!(
234            "Building hash index on {} table ({} rows)",
235            if build_is_left { "left" } else { "right" },
236            build_table.row_count()
237        );
238
239        // Build hash index on the smaller table
240        let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
241        for (row_idx, row) in build_table.rows.iter().enumerate() {
242            let key = row.values[build_col_idx].clone();
243            hash_index.entry(key).or_default().push(row_idx);
244        }
245
246        debug!(
247            "Hash index built with {} unique keys in {:?}",
248            hash_index.len(),
249            start.elapsed()
250        );
251
252        // Create result table with columns from both tables
253        let mut result = DataTable::new("joined");
254
255        // Add columns from left table
256        for col in &left_table.columns {
257            result.add_column(DataColumn {
258                name: col.name.clone(),
259                data_type: col.data_type.clone(),
260                nullable: col.nullable,
261                unique_values: col.unique_values,
262                null_count: col.null_count,
263                metadata: col.metadata.clone(),
264            });
265        }
266
267        // Add columns from right table
268        for col in &right_table.columns {
269            // Skip columns with duplicate names for now
270            if !left_table
271                .columns
272                .iter()
273                .any(|left_col| left_col.name == col.name)
274            {
275                result.add_column(DataColumn {
276                    name: col.name.clone(),
277                    data_type: col.data_type.clone(),
278                    nullable: col.nullable,
279                    unique_values: col.unique_values,
280                    null_count: col.null_count,
281                    metadata: col.metadata.clone(),
282                });
283            } else {
284                // If there's a name conflict, add with a suffix
285                result.add_column(DataColumn {
286                    name: format!("{}_right", col.name),
287                    data_type: col.data_type.clone(),
288                    nullable: col.nullable,
289                    unique_values: col.unique_values,
290                    null_count: col.null_count,
291                    metadata: col.metadata.clone(),
292                });
293            }
294        }
295
296        debug!(
297            "Joined table will have {} columns: {:?}",
298            result.column_count(),
299            result.column_names()
300        );
301
302        // Probe phase: iterate through the larger table
303        let mut match_count = 0;
304        for probe_row in &probe_table.rows {
305            let probe_key = &probe_row.values[probe_col_idx];
306
307            if let Some(matching_indices) = hash_index.get(probe_key) {
308                for &build_idx in matching_indices {
309                    let build_row = &build_table.rows[build_idx];
310
311                    // Create joined row based on which table was used for building
312                    let mut joined_row = DataRow { values: Vec::new() };
313
314                    if build_is_left {
315                        // Build was left, probe was right
316                        joined_row.values.extend_from_slice(&build_row.values);
317                        joined_row.values.extend_from_slice(&probe_row.values);
318                    } else {
319                        // Build was right, probe was left
320                        joined_row.values.extend_from_slice(&probe_row.values);
321                        joined_row.values.extend_from_slice(&build_row.values);
322                    }
323
324                    result.add_row(joined_row);
325                    match_count += 1;
326                }
327            }
328        }
329
330        info!(
331            "INNER JOIN complete: {} matches found in {:?}",
332            match_count,
333            start.elapsed()
334        );
335
336        Ok(result)
337    }
338
339    /// Hash join implementation for LEFT OUTER JOIN
340    fn hash_join_left(
341        &self,
342        left_table: Arc<DataTable>,
343        right_table: Arc<DataTable>,
344        left_col_idx: usize,
345        right_col_idx: usize,
346        _left_col_name: &str,
347        _right_col_name: &str,
348    ) -> Result<DataTable> {
349        let start = std::time::Instant::now();
350
351        debug!(
352            "Building hash index on right table ({} rows)",
353            right_table.row_count()
354        );
355
356        // Build hash index on right table
357        let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
358        for (row_idx, row) in right_table.rows.iter().enumerate() {
359            let key = row.values[right_col_idx].clone();
360            hash_index.entry(key).or_default().push(row_idx);
361        }
362
363        // Create result table with columns from both tables
364        let mut result = DataTable::new("joined");
365
366        // Add columns from left table
367        for col in &left_table.columns {
368            result.add_column(DataColumn {
369                name: col.name.clone(),
370                data_type: col.data_type.clone(),
371                nullable: col.nullable,
372                unique_values: col.unique_values,
373                null_count: col.null_count,
374                metadata: col.metadata.clone(),
375            });
376        }
377
378        // Add columns from right table (all nullable for LEFT JOIN)
379        for col in &right_table.columns {
380            // Skip columns with duplicate names for now
381            if !left_table
382                .columns
383                .iter()
384                .any(|left_col| left_col.name == col.name)
385            {
386                result.add_column(DataColumn {
387                    name: col.name.clone(),
388                    data_type: col.data_type.clone(),
389                    nullable: true, // Always nullable for outer join
390                    unique_values: col.unique_values,
391                    null_count: col.null_count,
392                    metadata: col.metadata.clone(),
393                });
394            } else {
395                // If there's a name conflict, add with a suffix
396                result.add_column(DataColumn {
397                    name: format!("{}_right", col.name),
398                    data_type: col.data_type.clone(),
399                    nullable: true, // Always nullable for outer join
400                    unique_values: col.unique_values,
401                    null_count: col.null_count,
402                    metadata: col.metadata.clone(),
403                });
404            }
405        }
406
407        debug!(
408            "LEFT JOIN table will have {} columns: {:?}",
409            result.column_count(),
410            result.column_names()
411        );
412
413        // Probe phase: iterate through left table
414        let mut match_count = 0;
415        let mut null_count = 0;
416
417        for left_row in &left_table.rows {
418            let left_key = &left_row.values[left_col_idx];
419
420            if let Some(matching_indices) = hash_index.get(left_key) {
421                // Found matches - emit joined rows
422                for &right_idx in matching_indices {
423                    let right_row = &right_table.rows[right_idx];
424
425                    let mut joined_row = DataRow { values: Vec::new() };
426                    joined_row.values.extend_from_slice(&left_row.values);
427                    joined_row.values.extend_from_slice(&right_row.values);
428
429                    result.add_row(joined_row);
430                    match_count += 1;
431                }
432            } else {
433                // No match - emit left row with NULLs for right columns
434                let mut joined_row = DataRow { values: Vec::new() };
435                joined_row.values.extend_from_slice(&left_row.values);
436
437                // Add NULL values for all right table columns
438                for _ in 0..right_table.column_count() {
439                    joined_row.values.push(DataValue::Null);
440                }
441
442                result.add_row(joined_row);
443                null_count += 1;
444            }
445        }
446
447        info!(
448            "LEFT JOIN complete: {} matches, {} nulls in {:?}",
449            match_count,
450            null_count,
451            start.elapsed()
452        );
453
454        Ok(result)
455    }
456
457    /// Cross join implementation
458    fn cross_join(
459        &self,
460        left_table: Arc<DataTable>,
461        right_table: Arc<DataTable>,
462    ) -> Result<DataTable> {
463        let start = std::time::Instant::now();
464
465        // Check for potential memory explosion
466        let result_rows = left_table.row_count() * right_table.row_count();
467        if result_rows > 1_000_000 {
468            return Err(anyhow!(
469                "CROSS JOIN would produce {} rows, which exceeds the safety limit",
470                result_rows
471            ));
472        }
473
474        // Create result table
475        let mut result = DataTable::new("joined");
476
477        // Add columns from both tables
478        for col in &left_table.columns {
479            result.add_column(col.clone());
480        }
481        for col in &right_table.columns {
482            result.add_column(col.clone());
483        }
484
485        // Generate Cartesian product
486        for left_row in &left_table.rows {
487            for right_row in &right_table.rows {
488                let mut joined_row = DataRow { values: Vec::new() };
489                joined_row.values.extend_from_slice(&left_row.values);
490                joined_row.values.extend_from_slice(&right_row.values);
491                result.add_row(joined_row);
492            }
493        }
494
495        info!(
496            "CROSS JOIN complete: {} rows in {:?}",
497            result.row_count(),
498            start.elapsed()
499        );
500
501        Ok(result)
502    }
503
504    /// Qualify column name to avoid conflicts
505    fn qualify_column_name(
506        &self,
507        col_name: &str,
508        table_side: &str,
509        left_join_col: &str,
510        right_join_col: &str,
511    ) -> String {
512        // Extract base column name (without table prefix)
513        let base_name = if let Some(dot_pos) = col_name.rfind('.') {
514            &col_name[dot_pos + 1..]
515        } else {
516            col_name
517        };
518
519        let left_base = if let Some(dot_pos) = left_join_col.rfind('.') {
520            &left_join_col[dot_pos + 1..]
521        } else {
522            left_join_col
523        };
524
525        let right_base = if let Some(dot_pos) = right_join_col.rfind('.') {
526            &right_join_col[dot_pos + 1..]
527        } else {
528            right_join_col
529        };
530
531        // If this column name appears in both join columns, qualify it
532        if base_name == left_base || base_name == right_base {
533            format!("{}_{}", table_side, base_name)
534        } else {
535            col_name.to_string()
536        }
537    }
538
539    /// Reverse a join operator for right joins
540    fn reverse_operator(&self, op: &JoinOperator) -> JoinOperator {
541        match op {
542            JoinOperator::Equal => JoinOperator::Equal,
543            JoinOperator::NotEqual => JoinOperator::NotEqual,
544            JoinOperator::LessThan => JoinOperator::GreaterThan,
545            JoinOperator::GreaterThan => JoinOperator::LessThan,
546            JoinOperator::LessThanOrEqual => JoinOperator::GreaterThanOrEqual,
547            JoinOperator::GreaterThanOrEqual => JoinOperator::LessThanOrEqual,
548        }
549    }
550
551    /// Compare two values based on the join operator
552    fn compare_values(&self, left: &DataValue, right: &DataValue, op: &JoinOperator) -> bool {
553        match op {
554            JoinOperator::Equal => left == right,
555            JoinOperator::NotEqual => left != right,
556            JoinOperator::LessThan => left < right,
557            JoinOperator::GreaterThan => left > right,
558            JoinOperator::LessThanOrEqual => left <= right,
559            JoinOperator::GreaterThanOrEqual => left >= right,
560        }
561    }
562
563    /// Nested loop join for INNER JOIN with inequality conditions
564    fn nested_loop_join_inner(
565        &self,
566        left_table: Arc<DataTable>,
567        right_table: Arc<DataTable>,
568        left_col_idx: usize,
569        right_col_idx: usize,
570        operator: &JoinOperator,
571    ) -> Result<DataTable> {
572        let start = std::time::Instant::now();
573
574        info!(
575            "Executing nested loop INNER JOIN with {:?} operator: {} x {} rows",
576            operator,
577            left_table.row_count(),
578            right_table.row_count()
579        );
580
581        // Create result table with columns from both tables
582        let mut result = DataTable::new("joined");
583
584        // Add columns from left table
585        for col in &left_table.columns {
586            result.add_column(DataColumn {
587                name: col.name.clone(),
588                data_type: col.data_type.clone(),
589                nullable: col.nullable,
590                unique_values: col.unique_values,
591                null_count: col.null_count,
592                metadata: col.metadata.clone(),
593            });
594        }
595
596        // Add columns from right table
597        for col in &right_table.columns {
598            if !left_table
599                .columns
600                .iter()
601                .any(|left_col| left_col.name == col.name)
602            {
603                result.add_column(DataColumn {
604                    name: col.name.clone(),
605                    data_type: col.data_type.clone(),
606                    nullable: col.nullable,
607                    unique_values: col.unique_values,
608                    null_count: col.null_count,
609                    metadata: col.metadata.clone(),
610                });
611            } else {
612                result.add_column(DataColumn {
613                    name: format!("{}_right", col.name),
614                    data_type: col.data_type.clone(),
615                    nullable: col.nullable,
616                    unique_values: col.unique_values,
617                    null_count: col.null_count,
618                    metadata: col.metadata.clone(),
619                });
620            }
621        }
622
623        // Nested loop join
624        let mut match_count = 0;
625        for left_row in &left_table.rows {
626            let left_value = &left_row.values[left_col_idx];
627
628            for right_row in &right_table.rows {
629                let right_value = &right_row.values[right_col_idx];
630
631                if self.compare_values(left_value, right_value, operator) {
632                    let mut joined_row = DataRow { values: Vec::new() };
633                    joined_row.values.extend_from_slice(&left_row.values);
634                    joined_row.values.extend_from_slice(&right_row.values);
635                    result.add_row(joined_row);
636                    match_count += 1;
637                }
638            }
639        }
640
641        info!(
642            "Nested loop INNER JOIN complete: {} matches found in {:?}",
643            match_count,
644            start.elapsed()
645        );
646
647        Ok(result)
648    }
649
650    /// Nested loop join for LEFT JOIN with inequality conditions
651    fn nested_loop_join_left(
652        &self,
653        left_table: Arc<DataTable>,
654        right_table: Arc<DataTable>,
655        left_col_idx: usize,
656        right_col_idx: usize,
657        operator: &JoinOperator,
658    ) -> Result<DataTable> {
659        let start = std::time::Instant::now();
660
661        info!(
662            "Executing nested loop LEFT JOIN with {:?} operator: {} x {} rows",
663            operator,
664            left_table.row_count(),
665            right_table.row_count()
666        );
667
668        // Create result table with columns from both tables
669        let mut result = DataTable::new("joined");
670
671        // Add columns from left table
672        for col in &left_table.columns {
673            result.add_column(DataColumn {
674                name: col.name.clone(),
675                data_type: col.data_type.clone(),
676                nullable: col.nullable,
677                unique_values: col.unique_values,
678                null_count: col.null_count,
679                metadata: col.metadata.clone(),
680            });
681        }
682
683        // Add columns from right table (all nullable for LEFT JOIN)
684        for col in &right_table.columns {
685            if !left_table
686                .columns
687                .iter()
688                .any(|left_col| left_col.name == col.name)
689            {
690                result.add_column(DataColumn {
691                    name: col.name.clone(),
692                    data_type: col.data_type.clone(),
693                    nullable: true, // Always nullable for outer join
694                    unique_values: col.unique_values,
695                    null_count: col.null_count,
696                    metadata: col.metadata.clone(),
697                });
698            } else {
699                result.add_column(DataColumn {
700                    name: format!("{}_right", col.name),
701                    data_type: col.data_type.clone(),
702                    nullable: true, // Always nullable for outer join
703                    unique_values: col.unique_values,
704                    null_count: col.null_count,
705                    metadata: col.metadata.clone(),
706                });
707            }
708        }
709
710        // Nested loop join
711        let mut match_count = 0;
712        let mut null_count = 0;
713
714        for left_row in &left_table.rows {
715            let left_value = &left_row.values[left_col_idx];
716            let mut found_match = false;
717
718            for right_row in &right_table.rows {
719                let right_value = &right_row.values[right_col_idx];
720
721                if self.compare_values(left_value, right_value, operator) {
722                    let mut joined_row = DataRow { values: Vec::new() };
723                    joined_row.values.extend_from_slice(&left_row.values);
724                    joined_row.values.extend_from_slice(&right_row.values);
725                    result.add_row(joined_row);
726                    match_count += 1;
727                    found_match = true;
728                }
729            }
730
731            // If no match found, emit left row with NULLs for right columns
732            if !found_match {
733                let mut joined_row = DataRow { values: Vec::new() };
734                joined_row.values.extend_from_slice(&left_row.values);
735                for _ in 0..right_table.column_count() {
736                    joined_row.values.push(DataValue::Null);
737                }
738                result.add_row(joined_row);
739                null_count += 1;
740            }
741        }
742
743        info!(
744            "Nested loop LEFT JOIN complete: {} matches, {} nulls in {:?}",
745            match_count,
746            null_count,
747            start.elapsed()
748        );
749
750        Ok(result)
751    }
752}