vibesql_executor/select/join/
hash_join_iterator.rs

1//! Hash join iterator implementation for lazy evaluation
2//!
3//! This module implements an iterator-based hash join that provides O(N+M)
4//! performance while maintaining lazy evaluation for the left (probe) side.
5
6#![allow(clippy::manual_is_multiple_of)]
7
8use ahash::AHashMap;
9
10use super::{combine_rows, FromResult};
11use crate::{
12    errors::ExecutorError,
13    schema::CombinedSchema,
14    select::RowIterator,
15    timeout::{TimeoutContext, CHECK_INTERVAL},
16};
17
18/// Hash join iterator that lazily produces joined rows
19///
20/// This implementation uses a hash join algorithm with:
21/// - Lazy left (probe) side: rows consumed on-demand from iterator
22/// - Materialized right (build) side: all rows hashed into HashMap
23///
24/// Algorithm:
25/// 1. Build phase: Materialize right side into hash table (one-time cost)
26/// 2. Probe phase: Stream left rows, hash lookup for matches (O(1) per row)
27///
28/// Performance: O(N + M) where N=left rows, M=right rows
29///
30/// Memory: O(M) for right side hash table + O(K) for current matches
31pub struct HashJoinIterator<L: RowIterator> {
32    /// Lazy probe side (left)
33    left: L,
34    /// Materialized build side (right) - hash table mapping join key to rows
35    right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
36    /// Combined schema for output rows
37    schema: CombinedSchema,
38    /// Column index in left table for join key
39    left_col_idx: usize,
40    /// Column index in right table for join key
41    #[allow(dead_code)]
42    right_col_idx: usize,
43    /// Current left row being processed
44    current_left_row: Option<vibesql_storage::Row>,
45    /// Matching right rows for current left row
46    current_matches: Vec<vibesql_storage::Row>,
47    /// Index into current_matches
48    match_index: usize,
49    /// Number of right columns (for NULL padding)
50    #[allow(dead_code)]
51    right_col_count: usize,
52    /// Timeout context for query timeout enforcement
53    timeout_ctx: TimeoutContext,
54    /// Iteration counter for periodic timeout checks
55    iteration_count: usize,
56}
57
58impl<L: RowIterator> HashJoinIterator<L> {
59    /// Create a new hash join iterator for INNER JOIN
60    ///
61    /// # Arguments
62    /// * `left` - Lazy iterator for left (probe) side
63    /// * `right` - Materialized right (build) side
64    /// * `left_col_idx` - Column index in left table for join key
65    /// * `right_col_idx` - Column index in right table for join key
66    ///
67    /// # Returns
68    /// * `Ok(HashJoinIterator)` - Successfully created iterator
69    /// * `Err(ExecutorError)` - Failed due to memory limits or schema issues
70    #[allow(private_interfaces)]
71    pub fn new(
72        left: L,
73        right: FromResult,
74        left_col_idx: usize,
75        right_col_idx: usize,
76    ) -> Result<Self, ExecutorError> {
77        // Extract right table schema
78        let right_table_name = right
79            .schema
80            .table_schemas
81            .keys()
82            .next()
83            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
84            .clone();
85
86        let right_schema = right
87            .schema
88            .table_schemas
89            .get(&right_table_name)
90            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
91            .1
92            .clone();
93
94        let right_col_count = right_schema.columns.len();
95
96        // Combine schemas (left schema from iterator + right schema)
97        let combined_schema =
98            CombinedSchema::combine(left.schema().clone(), right_table_name, right_schema);
99
100        // Use default timeout context (proper propagation from SelectExecutor is a future improvement)
101        let timeout_ctx = TimeoutContext::new_default();
102
103        // Build phase: Create hash table from right side
104        // This is the one-time materialization cost
105        let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> =
106            AHashMap::new();
107        let mut build_iterations = 0;
108
109        for row in right.into_rows() {
110            // Check timeout periodically during build phase
111            build_iterations += 1;
112            if build_iterations % CHECK_INTERVAL == 0 {
113                timeout_ctx.check()?;
114            }
115
116            let key = row.values[right_col_idx].clone();
117
118            // Skip NULL values - they never match in equi-joins
119            if key != vibesql_types::SqlValue::Null {
120                hash_table.entry(key).or_default().push(row);
121            }
122        }
123
124        Ok(Self {
125            left,
126            right_hash_table: hash_table,
127            schema: combined_schema,
128            left_col_idx,
129            right_col_idx,
130            current_left_row: None,
131            current_matches: Vec::new(),
132            match_index: 0,
133            right_col_count,
134            timeout_ctx,
135            iteration_count: 0,
136        })
137    }
138
139    /// Get the number of rows in the hash table (right side)
140    pub fn hash_table_size(&self) -> usize {
141        self.right_hash_table.values().map(|v| v.len()).sum()
142    }
143}
144
145impl<L: RowIterator> Iterator for HashJoinIterator<L> {
146    type Item = Result<vibesql_storage::Row, ExecutorError>;
147
148    fn next(&mut self) -> Option<Self::Item> {
149        loop {
150            // Check timeout periodically during probe phase
151            self.iteration_count += 1;
152            if self.iteration_count % CHECK_INTERVAL == 0 {
153                if let Err(e) = self.timeout_ctx.check() {
154                    return Some(Err(e));
155                }
156            }
157
158            // If we have remaining matches for current left row, return next match
159            if self.match_index < self.current_matches.len() {
160                let right_row = &self.current_matches[self.match_index];
161                self.match_index += 1;
162
163                // Combine left and right rows
164                if let Some(ref left_row) = self.current_left_row {
165                    let combined_row = combine_rows(left_row, right_row);
166                    return Some(Ok(combined_row));
167                }
168            }
169
170            // No more matches for current left row, get next left row
171            match self.left.next() {
172                Some(Ok(left_row)) => {
173                    let key = &left_row.values[self.left_col_idx];
174
175                    // Skip NULL values - they never match in equi-joins
176                    if key == &vibesql_types::SqlValue::Null {
177                        // For INNER JOIN, skip rows with NULL join keys
178                        continue;
179                    }
180
181                    // Lookup matches in hash table
182                    if let Some(matches) = self.right_hash_table.get(key) {
183                        // Found matches - set up for iteration
184                        self.current_left_row = Some(left_row);
185                        self.current_matches = matches.clone();
186                        self.match_index = 0;
187                        // Continue loop to return first match
188                    } else {
189                        // No matches for this left row
190                        // For INNER JOIN, skip this row
191                        continue;
192                    }
193                }
194                Some(Err(e)) => {
195                    // Propagate error from left iterator
196                    return Some(Err(e));
197                }
198                None => {
199                    // Left iterator exhausted, we're done
200                    return None;
201                }
202            }
203        }
204    }
205}
206
207impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
208    fn schema(&self) -> &CombinedSchema {
209        &self.schema
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::select::TableScanIterator;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    /// Builds an in-memory `FromResult` describing a single table.
    ///
    /// Every column is declared nullable; each inner `Vec<SqlValue>` becomes one row.
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let schema = TableSchema::new(
            table_name.to_string(),
            columns
                .iter()
                .map(|(name, dtype)| {
                    ColumnSchema::new(
                        name.to_string(),
                        dtype.clone(),
                        true, // nullable
                    )
                })
                .collect(),
        );

        let combined_schema = CombinedSchema::from_table(table_name.to_string(), schema);
        let rows = rows.into_iter().map(Row::new).collect();

        FromResult::from_rows(combined_schema, rows)
    }

    /// Scans `left` lazily, hash-joins it against `right` on the given key columns
    /// and collects every joined row, panicking on any executor error.
    fn join_and_collect(
        left: FromResult,
        right: FromResult,
        left_col_idx: usize,
        right_col_idx: usize,
    ) -> Vec<Row> {
        let left_iter = TableScanIterator::new(left.schema.clone(), left.into_rows());

        HashJoinIterator::new(left_iter, right, left_col_idx, right_col_idx)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    /// One-to-many join: matched rows are combined, unmatched left rows are dropped.
    #[test]
    fn test_hash_join_iterator_simple() {
        // Left table: users(id, name)
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Integer(2), SqlValue::Varchar("Bob".to_string())],
                vec![SqlValue::Integer(3), SqlValue::Varchar("Charlie".to_string())],
            ],
        );

        // Right table: orders(user_id, amount)
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        // Join on users.id = orders.user_id (column 0 from both sides)
        let results = join_and_collect(users, orders, 0, 0);

        // Should have 3 rows (user 1 has 2 orders, user 2 has 1 order, user 3 has no orders)
        assert_eq!(results.len(), 3);

        // Verify combined rows have correct structure (4 columns: id, name, user_id, amount)
        for row in &results {
            assert_eq!(row.values.len(), 4);
        }

        // Alice (id=1) should appear twice (2 orders)
        let alice_orders = results.iter().filter(|r| r.values[0] == SqlValue::Integer(1)).count();
        assert_eq!(alice_orders, 2);

        // Bob (id=2) should appear once (1 order)
        let bob_orders = results.iter().filter(|r| r.values[0] == SqlValue::Integer(2)).count();
        assert_eq!(bob_orders, 1);

        // Charlie (id=3) should not appear (no orders)
        let charlie_orders =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(3)).count();
        assert_eq!(charlie_orders, 0);
    }

    /// NULL join keys on either side never produce a match, not even with each other.
    #[test]
    fn test_hash_join_iterator_null_values() {
        // Left table with NULL id
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Null, SqlValue::Varchar("Unknown".to_string())],
            ],
        );

        // Right table with NULL user_id
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let results = join_and_collect(users, orders, 0, 0);

        // Only one match: Alice (id=1) with order (user_id=1)
        // NULLs should not match each other in equi-joins
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].values[0], SqlValue::Integer(1)); // user id
        assert_eq!(results[0].values[1], SqlValue::Varchar("Alice".to_string())); // user name
        assert_eq!(results[0].values[2], SqlValue::Integer(1)); // order user_id
        assert_eq!(results[0].values[3], SqlValue::Integer(100)); // order amount
    }

    /// Disjoint key sets produce an empty result.
    #[test]
    fn test_hash_join_iterator_no_matches() {
        // Left table
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        // Right table with non-matching ids
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        let results = join_and_collect(users, orders, 0, 0);

        // No matches
        assert_eq!(results.len(), 0);
    }

    /// Joining two empty tables succeeds and yields nothing.
    #[test]
    fn test_hash_join_iterator_empty_tables() {
        // Left table (empty)
        let users = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);

        // Right table (empty)
        let orders = create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        let results = join_and_collect(users, orders, 0, 0);

        // No rows
        assert_eq!(results.len(), 0);
    }

    /// Duplicate keys on both sides expand to the per-key cartesian product.
    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        // Left table with duplicate ids
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("admin".to_string())],
                vec![SqlValue::Integer(1), SqlValue::Varchar("user".to_string())],
            ],
        );

        // Right table with duplicate user_ids
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let results = join_and_collect(users, orders, 0, 0);

        // Cartesian product of matching keys: 2 left rows * 2 right rows = 4 results
        assert_eq!(results.len(), 4);

        // All should have id=1
        for row in &results {
            assert_eq!(row.values[0], SqlValue::Integer(1));
        }
    }

    /// The probe side must be advanced only as far as the requested output requires.
    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        /// Row source that records how many rows have been pulled from it.
        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            index: usize,
            consumed_count: std::sync::Arc<std::sync::Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                if self.index < self.rows.len() {
                    let row = self.rows[self.index].clone();
                    self.index += 1;
                    *self.consumed_count.lock().unwrap() += 1;
                    Some(Ok(row))
                } else {
                    None
                }
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let consumed = std::sync::Arc::new(std::sync::Mutex::new(0));

        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(2)],
                vec![SqlValue::Integer(3)],
                vec![SqlValue::Integer(4)],
                vec![SqlValue::Integer(5)],
            ],
        );

        let counting_iter = CountingIterator {
            schema: left_result.schema.clone(),
            rows: left_result.into_rows(),
            index: 0,
            consumed_count: consumed.clone(),
        };

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let join_iter = HashJoinIterator::new(counting_iter, right, 0, 0).unwrap();

        // Take only first 2 results
        let results: Vec<_> = join_iter.take(2).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(results.len(), 2);

        // Ids 1 and 2 each match exactly one right row, so two output rows need only
        // the first two left rows. Pulling a third row would mean the probe side
        // reads ahead, i.e. is no longer lazy.
        let consumed_count = *consumed.lock().unwrap();
        assert!(consumed_count <= 2, "Expected at most 2 rows consumed, got {}", consumed_count);
    }
}