sql_cli/data/
hash_join.rs

//! Hash join implementation for equality-based JOIN operations (INNER, LEFT, RIGHT, CROSS).
2
3use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tracing::{debug, info};
7
8use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
9use crate::sql::parser::ast::{JoinClause, JoinOperator, JoinType};
10
/// Hash join executor for efficient JOIN operations.
///
/// Holds only configuration; the tables being joined are passed to each call of
/// `execute_join`.
#[derive(Debug, Clone)]
pub struct HashJoinExecutor {
    /// When true, column names from the join condition are matched against table column
    /// names ignoring case.
    case_insensitive: bool,
}
15
16impl HashJoinExecutor {
17    pub fn new(case_insensitive: bool) -> Self {
18        Self { case_insensitive }
19    }
20
21    /// Execute a single join operation
22    pub fn execute_join(
23        &self,
24        left_table: Arc<DataTable>,
25        join_clause: &JoinClause,
26        right_table: Arc<DataTable>,
27    ) -> Result<DataTable> {
28        info!(
29            "Executing {:?} JOIN: {} rows x {} rows",
30            join_clause.join_type,
31            left_table.row_count(),
32            right_table.row_count()
33        );
34
35        // Extract column references from the join condition
36        let (left_col_name, right_col_name) = self.parse_join_columns(join_clause)?;
37
38        // Find column indices
39        let left_col_idx = self.find_column_index(&left_table, &left_col_name)?;
40        let right_col_idx = self.find_column_index(&right_table, &right_col_name)?;
41
42        // Perform the appropriate join based on type
43        match join_clause.join_type {
44            JoinType::Inner => self.hash_join_inner(
45                left_table,
46                right_table,
47                left_col_idx,
48                right_col_idx,
49                &left_col_name,
50                &right_col_name,
51            ),
52            JoinType::Left => self.hash_join_left(
53                left_table,
54                right_table,
55                left_col_idx,
56                right_col_idx,
57                &left_col_name,
58                &right_col_name,
59            ),
60            JoinType::Right => {
61                // Right join is just a left join with tables swapped
62                self.hash_join_left(
63                    right_table,
64                    left_table,
65                    right_col_idx,
66                    left_col_idx,
67                    &right_col_name,
68                    &left_col_name,
69                )
70            }
71            JoinType::Cross => self.cross_join(left_table, right_table),
72            JoinType::Full => {
73                return Err(anyhow!("FULL OUTER JOIN not yet implemented"));
74            }
75        }
76    }
77
78    /// Parse join columns from the join condition
79    fn parse_join_columns(&self, join_clause: &JoinClause) -> Result<(String, String)> {
80        // For now, we only support simple equality conditions
81        if join_clause.condition.operator != JoinOperator::Equal {
82            return Err(anyhow!(
83                "Only equality JOIN conditions are currently supported"
84            ));
85        }
86
87        Ok((
88            join_clause.condition.left_column.clone(),
89            join_clause.condition.right_column.clone(),
90        ))
91    }
92
93    /// Find column index in a table
94    fn find_column_index(&self, table: &DataTable, col_name: &str) -> Result<usize> {
95        // Handle table-qualified column names (e.g., "t1.id")
96        let col_name = if let Some(dot_pos) = col_name.rfind('.') {
97            &col_name[dot_pos + 1..]
98        } else {
99            col_name
100        };
101
102        table
103            .columns
104            .iter()
105            .position(|col| {
106                if self.case_insensitive {
107                    col.name.to_lowercase() == col_name.to_lowercase()
108                } else {
109                    col.name == col_name
110                }
111            })
112            .ok_or_else(|| anyhow!("Column '{}' not found in table", col_name))
113    }
114
115    /// Hash join implementation for INNER JOIN
116    fn hash_join_inner(
117        &self,
118        left_table: Arc<DataTable>,
119        right_table: Arc<DataTable>,
120        left_col_idx: usize,
121        right_col_idx: usize,
122        _left_col_name: &str,
123        _right_col_name: &str,
124    ) -> Result<DataTable> {
125        let start = std::time::Instant::now();
126
127        // Determine which table to use for building the hash index (prefer smaller)
128        let (build_table, probe_table, build_col_idx, probe_col_idx, build_is_left) =
129            if left_table.row_count() <= right_table.row_count() {
130                (
131                    left_table.clone(),
132                    right_table.clone(),
133                    left_col_idx,
134                    right_col_idx,
135                    true,
136                )
137            } else {
138                (
139                    right_table.clone(),
140                    left_table.clone(),
141                    right_col_idx,
142                    left_col_idx,
143                    false,
144                )
145            };
146
147        debug!(
148            "Building hash index on {} table ({} rows)",
149            if build_is_left { "left" } else { "right" },
150            build_table.row_count()
151        );
152
153        // Build hash index on the smaller table
154        let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
155        for (row_idx, row) in build_table.rows.iter().enumerate() {
156            let key = row.values[build_col_idx].clone();
157            hash_index.entry(key).or_default().push(row_idx);
158        }
159
160        debug!(
161            "Hash index built with {} unique keys in {:?}",
162            hash_index.len(),
163            start.elapsed()
164        );
165
166        // Create result table with columns from both tables
167        let mut result = DataTable::new("joined");
168
169        // Add columns from left table
170        for col in &left_table.columns {
171            result.add_column(DataColumn {
172                name: col.name.clone(),
173                data_type: col.data_type.clone(),
174                nullable: col.nullable,
175                unique_values: col.unique_values,
176                null_count: col.null_count,
177                metadata: col.metadata.clone(),
178            });
179        }
180
181        // Add columns from right table
182        for col in &right_table.columns {
183            // Skip columns with duplicate names for now
184            if !left_table
185                .columns
186                .iter()
187                .any(|left_col| left_col.name == col.name)
188            {
189                result.add_column(DataColumn {
190                    name: col.name.clone(),
191                    data_type: col.data_type.clone(),
192                    nullable: col.nullable,
193                    unique_values: col.unique_values,
194                    null_count: col.null_count,
195                    metadata: col.metadata.clone(),
196                });
197            } else {
198                // If there's a name conflict, add with a suffix
199                result.add_column(DataColumn {
200                    name: format!("{}_right", col.name),
201                    data_type: col.data_type.clone(),
202                    nullable: col.nullable,
203                    unique_values: col.unique_values,
204                    null_count: col.null_count,
205                    metadata: col.metadata.clone(),
206                });
207            }
208        }
209
210        debug!(
211            "Joined table will have {} columns: {:?}",
212            result.column_count(),
213            result.column_names()
214        );
215
216        // Probe phase: iterate through the larger table
217        let mut match_count = 0;
218        for probe_row in &probe_table.rows {
219            let probe_key = &probe_row.values[probe_col_idx];
220
221            if let Some(matching_indices) = hash_index.get(probe_key) {
222                for &build_idx in matching_indices {
223                    let build_row = &build_table.rows[build_idx];
224
225                    // Create joined row based on which table was used for building
226                    let mut joined_row = DataRow { values: Vec::new() };
227
228                    if build_is_left {
229                        // Build was left, probe was right
230                        joined_row.values.extend_from_slice(&build_row.values);
231                        joined_row.values.extend_from_slice(&probe_row.values);
232                    } else {
233                        // Build was right, probe was left
234                        joined_row.values.extend_from_slice(&probe_row.values);
235                        joined_row.values.extend_from_slice(&build_row.values);
236                    }
237
238                    result.add_row(joined_row);
239                    match_count += 1;
240                }
241            }
242        }
243
244        info!(
245            "INNER JOIN complete: {} matches found in {:?}",
246            match_count,
247            start.elapsed()
248        );
249
250        Ok(result)
251    }
252
253    /// Hash join implementation for LEFT OUTER JOIN
254    fn hash_join_left(
255        &self,
256        left_table: Arc<DataTable>,
257        right_table: Arc<DataTable>,
258        left_col_idx: usize,
259        right_col_idx: usize,
260        _left_col_name: &str,
261        _right_col_name: &str,
262    ) -> Result<DataTable> {
263        let start = std::time::Instant::now();
264
265        debug!(
266            "Building hash index on right table ({} rows)",
267            right_table.row_count()
268        );
269
270        // Build hash index on right table
271        let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
272        for (row_idx, row) in right_table.rows.iter().enumerate() {
273            let key = row.values[right_col_idx].clone();
274            hash_index.entry(key).or_default().push(row_idx);
275        }
276
277        // Create result table with columns from both tables
278        let mut result = DataTable::new("joined");
279
280        // Add columns from left table
281        for col in &left_table.columns {
282            result.add_column(DataColumn {
283                name: col.name.clone(),
284                data_type: col.data_type.clone(),
285                nullable: col.nullable,
286                unique_values: col.unique_values,
287                null_count: col.null_count,
288                metadata: col.metadata.clone(),
289            });
290        }
291
292        // Add columns from right table (all nullable for LEFT JOIN)
293        for col in &right_table.columns {
294            // Skip columns with duplicate names for now
295            if !left_table
296                .columns
297                .iter()
298                .any(|left_col| left_col.name == col.name)
299            {
300                result.add_column(DataColumn {
301                    name: col.name.clone(),
302                    data_type: col.data_type.clone(),
303                    nullable: true, // Always nullable for outer join
304                    unique_values: col.unique_values,
305                    null_count: col.null_count,
306                    metadata: col.metadata.clone(),
307                });
308            } else {
309                // If there's a name conflict, add with a suffix
310                result.add_column(DataColumn {
311                    name: format!("{}_right", col.name),
312                    data_type: col.data_type.clone(),
313                    nullable: true, // Always nullable for outer join
314                    unique_values: col.unique_values,
315                    null_count: col.null_count,
316                    metadata: col.metadata.clone(),
317                });
318            }
319        }
320
321        debug!(
322            "LEFT JOIN table will have {} columns: {:?}",
323            result.column_count(),
324            result.column_names()
325        );
326
327        // Probe phase: iterate through left table
328        let mut match_count = 0;
329        let mut null_count = 0;
330
331        for left_row in &left_table.rows {
332            let left_key = &left_row.values[left_col_idx];
333
334            if let Some(matching_indices) = hash_index.get(left_key) {
335                // Found matches - emit joined rows
336                for &right_idx in matching_indices {
337                    let right_row = &right_table.rows[right_idx];
338
339                    let mut joined_row = DataRow { values: Vec::new() };
340                    joined_row.values.extend_from_slice(&left_row.values);
341                    joined_row.values.extend_from_slice(&right_row.values);
342
343                    result.add_row(joined_row);
344                    match_count += 1;
345                }
346            } else {
347                // No match - emit left row with NULLs for right columns
348                let mut joined_row = DataRow { values: Vec::new() };
349                joined_row.values.extend_from_slice(&left_row.values);
350
351                // Add NULL values for all right table columns
352                for _ in 0..right_table.column_count() {
353                    joined_row.values.push(DataValue::Null);
354                }
355
356                result.add_row(joined_row);
357                null_count += 1;
358            }
359        }
360
361        info!(
362            "LEFT JOIN complete: {} matches, {} nulls in {:?}",
363            match_count,
364            null_count,
365            start.elapsed()
366        );
367
368        Ok(result)
369    }
370
371    /// Cross join implementation
372    fn cross_join(
373        &self,
374        left_table: Arc<DataTable>,
375        right_table: Arc<DataTable>,
376    ) -> Result<DataTable> {
377        let start = std::time::Instant::now();
378
379        // Check for potential memory explosion
380        let result_rows = left_table.row_count() * right_table.row_count();
381        if result_rows > 1_000_000 {
382            return Err(anyhow!(
383                "CROSS JOIN would produce {} rows, which exceeds the safety limit",
384                result_rows
385            ));
386        }
387
388        // Create result table
389        let mut result = DataTable::new("joined");
390
391        // Add columns from both tables
392        for col in &left_table.columns {
393            result.add_column(col.clone());
394        }
395        for col in &right_table.columns {
396            result.add_column(col.clone());
397        }
398
399        // Generate Cartesian product
400        for left_row in &left_table.rows {
401            for right_row in &right_table.rows {
402                let mut joined_row = DataRow { values: Vec::new() };
403                joined_row.values.extend_from_slice(&left_row.values);
404                joined_row.values.extend_from_slice(&right_row.values);
405                result.add_row(joined_row);
406            }
407        }
408
409        info!(
410            "CROSS JOIN complete: {} rows in {:?}",
411            result.row_count(),
412            start.elapsed()
413        );
414
415        Ok(result)
416    }
417
418    /// Qualify column name to avoid conflicts
419    fn qualify_column_name(
420        &self,
421        col_name: &str,
422        table_side: &str,
423        left_join_col: &str,
424        right_join_col: &str,
425    ) -> String {
426        // Extract base column name (without table prefix)
427        let base_name = if let Some(dot_pos) = col_name.rfind('.') {
428            &col_name[dot_pos + 1..]
429        } else {
430            col_name
431        };
432
433        let left_base = if let Some(dot_pos) = left_join_col.rfind('.') {
434            &left_join_col[dot_pos + 1..]
435        } else {
436            left_join_col
437        };
438
439        let right_base = if let Some(dot_pos) = right_join_col.rfind('.') {
440            &right_join_col[dot_pos + 1..]
441        } else {
442            right_join_col
443        };
444
445        // If this column name appears in both join columns, qualify it
446        if base_name == left_base || base_name == right_base {
447            format!("{}_{}", table_side, base_name)
448        } else {
449            col_name.to_string()
450        }
451    }
452}