vibesql_executor/select/join/
hash_join_iterator.rs

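//! Hash join iterator implementation for lazy evaluation
//!
//! This module implements an iterator-based hash join that runs in O(N + M) time (plus the
//! size of the output) while keeping the left (probe) side lazy: left rows are pulled only as
//! output is requested. The right (build) side is fully materialized up front. Only INNER JOIN
//! semantics are provided: unmatched left rows and NULL join keys produce no output.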
5
6#![allow(clippy::manual_is_multiple_of)]
7
8use ahash::AHashMap;
9
10use super::{combine_rows, FromResult};
11use crate::{
12    errors::ExecutorError,
13    schema::CombinedSchema,
14    select::RowIterator,
15    timeout::{TimeoutContext, CHECK_INTERVAL},
16};
17
18/// Hash join iterator that lazily produces joined rows
19///
20/// This implementation uses a hash join algorithm with:
21/// - Lazy left (probe) side: rows consumed on-demand from iterator
22/// - Materialized right (build) side: all rows hashed into HashMap
23///
24/// Algorithm:
25/// 1. Build phase: Materialize right side into hash table (one-time cost)
26/// 2. Probe phase: Stream left rows, hash lookup for matches (O(1) per row)
27///
28/// Performance: O(N + M) where N=left rows, M=right rows
29///
30/// Memory: O(M) for right side hash table + O(K) for current matches
31pub struct HashJoinIterator<L: RowIterator> {
32    /// Lazy probe side (left)
33    left: L,
34    /// Materialized build side (right) - hash table mapping join key to rows
35    right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
36    /// Combined schema for output rows
37    schema: CombinedSchema,
38    /// Column index in left table for join key
39    left_col_idx: usize,
40    /// Column index in right table for join key
41    #[allow(dead_code)]
42    right_col_idx: usize,
43    /// Current left row being processed
44    current_left_row: Option<vibesql_storage::Row>,
45    /// Matching right rows for current left row
46    current_matches: Vec<vibesql_storage::Row>,
47    /// Index into current_matches
48    match_index: usize,
49    /// Number of right columns (for NULL padding)
50    #[allow(dead_code)]
51    right_col_count: usize,
52    /// Timeout context for query timeout enforcement
53    timeout_ctx: TimeoutContext,
54    /// Iteration counter for periodic timeout checks
55    iteration_count: usize,
56}
57
58impl<L: RowIterator> HashJoinIterator<L> {
59    /// Create a new hash join iterator for INNER JOIN
60    ///
61    /// # Arguments
62    /// * `left` - Lazy iterator for left (probe) side
63    /// * `right` - Materialized right (build) side
64    /// * `left_col_idx` - Column index in left table for join key
65    /// * `right_col_idx` - Column index in right table for join key
66    ///
67    /// # Returns
68    /// * `Ok(HashJoinIterator)` - Successfully created iterator
69    /// * `Err(ExecutorError)` - Failed due to memory limits or schema issues
70    #[allow(private_interfaces)]
71    pub fn new(
72        left: L,
73        right: FromResult,
74        left_col_idx: usize,
75        right_col_idx: usize,
76    ) -> Result<Self, ExecutorError> {
77        // Extract right table schema
78        let right_table_name = right
79            .schema
80            .table_schemas
81            .keys()
82            .next()
83            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
84            .clone();
85
86        let right_schema = right
87            .schema
88            .table_schemas
89            .get(&right_table_name)
90            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
91            .1
92            .clone();
93
94        let right_col_count = right_schema.columns.len();
95
96        // Combine schemas (left schema from iterator + right schema)
97        let combined_schema =
98            CombinedSchema::combine(left.schema().clone(), right_table_name, right_schema);
99
100        // Use default timeout context (proper propagation from SelectExecutor is a future improvement)
101        let timeout_ctx = TimeoutContext::new_default();
102
103        // Build phase: Create hash table from right side
104        // This is the one-time materialization cost
105        let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> = AHashMap::new();
106        let mut build_iterations = 0;
107
108        for row in right.into_rows() {
109            // Check timeout periodically during build phase
110            build_iterations += 1;
111            if build_iterations % CHECK_INTERVAL == 0 {
112                timeout_ctx.check()?;
113            }
114
115            let key = row.values[right_col_idx].clone();
116
117            // Skip NULL values - they never match in equi-joins
118            if key != vibesql_types::SqlValue::Null {
119                hash_table.entry(key).or_default().push(row);
120            }
121        }
122
123        Ok(Self {
124            left,
125            right_hash_table: hash_table,
126            schema: combined_schema,
127            left_col_idx,
128            right_col_idx,
129            current_left_row: None,
130            current_matches: Vec::new(),
131            match_index: 0,
132            right_col_count,
133            timeout_ctx,
134            iteration_count: 0,
135        })
136    }
137
138    /// Get the number of rows in the hash table (right side)
139    pub fn hash_table_size(&self) -> usize {
140        self.right_hash_table.values().map(|v| v.len()).sum()
141    }
142}
143
144impl<L: RowIterator> Iterator for HashJoinIterator<L> {
145    type Item = Result<vibesql_storage::Row, ExecutorError>;
146
147    fn next(&mut self) -> Option<Self::Item> {
148        loop {
149            // Check timeout periodically during probe phase
150            self.iteration_count += 1;
151            if self.iteration_count % CHECK_INTERVAL == 0 {
152                if let Err(e) = self.timeout_ctx.check() {
153                    return Some(Err(e));
154                }
155            }
156
157            // If we have remaining matches for current left row, return next match
158            if self.match_index < self.current_matches.len() {
159                let right_row = &self.current_matches[self.match_index];
160                self.match_index += 1;
161
162                // Combine left and right rows
163                if let Some(ref left_row) = self.current_left_row {
164                    let combined_row = combine_rows(left_row, right_row);
165                    return Some(Ok(combined_row));
166                }
167            }
168
169            // No more matches for current left row, get next left row
170            match self.left.next() {
171                Some(Ok(left_row)) => {
172                    let key = &left_row.values[self.left_col_idx];
173
174                    // Skip NULL values - they never match in equi-joins
175                    if key == &vibesql_types::SqlValue::Null {
176                        // For INNER JOIN, skip rows with NULL join keys
177                        continue;
178                    }
179
180                    // Lookup matches in hash table
181                    if let Some(matches) = self.right_hash_table.get(key) {
182                        // Found matches - set up for iteration
183                        self.current_left_row = Some(left_row);
184                        self.current_matches = matches.clone();
185                        self.match_index = 0;
186                        // Continue loop to return first match
187                    } else {
188                        // No matches for this left row
189                        // For INNER JOIN, skip this row
190                        continue;
191                    }
192                }
193                Some(Err(e)) => {
194                    // Propagate error from left iterator
195                    return Some(Err(e));
196                }
197                None => {
198                    // Left iterator exhausted, we're done
199                    return None;
200                }
201            }
202        }
203    }
204}
205
206impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
207    fn schema(&self) -> &CombinedSchema {
208        &self.schema
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::select::TableScanIterator;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    /// Helper to create a simple single-table FromResult for testing
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let schema = TableSchema::new(
            table_name.to_string(),
            columns
                .iter()
                .map(|(name, dtype)| {
                    ColumnSchema::new(
                        name.to_string(),
                        dtype.clone(),
                        true, // nullable
                    )
                })
                .collect(),
        );

        let combined_schema = CombinedSchema::from_table(table_name.to_string(), schema);
        let rows = rows.into_iter().map(Row::new).collect();

        FromResult::from_rows(combined_schema, rows)
    }

    #[test]
    fn test_hash_join_iterator_simple() {
        // Left table: users(id, name)
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Integer(2), SqlValue::Varchar("Bob".to_string())],
                vec![SqlValue::Integer(3), SqlValue::Varchar("Charlie".to_string())],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table: orders(user_id, amount)
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        // Join on users.id = orders.user_id (column 0 from both sides)
        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        // Collect results
        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // Should have 3 rows (user 1 has 2 orders, user 2 has 1 order, user 3 has no orders)
        assert_eq!(results.len(), 3);

        // Verify combined rows have correct structure (4 columns: id, name, user_id, amount)
        for row in &results {
            assert_eq!(row.values.len(), 4);
        }

        // Alice (id=1) should appear twice (2 orders)
        let alice_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(1)).collect();
        assert_eq!(alice_orders.len(), 2);

        // Bob (id=2) should appear once (1 order)
        let bob_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(2)).collect();
        assert_eq!(bob_orders.len(), 1);

        // Charlie (id=3) should not appear (no orders)
        let charlie_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(3)).collect();
        assert_eq!(charlie_orders.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_null_values() {
        // Left table with NULL id
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Null, SqlValue::Varchar("Unknown".to_string())],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table with NULL user_id
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // Only one match: Alice (id=1) with order (user_id=1)
        // NULLs should not match each other in equi-joins
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].values[0], SqlValue::Integer(1)); // user id
        assert_eq!(results[0].values[1], SqlValue::Varchar("Alice".to_string())); // user name
        assert_eq!(results[0].values[2], SqlValue::Integer(1)); // order user_id
        assert_eq!(results[0].values[3], SqlValue::Integer(100)); // order amount
    }

    #[test]
    fn test_hash_join_iterator_no_matches() {
        // Left table
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table with non-matching ids
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // No matches
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_empty_tables() {
        // Left table (empty)
        let left_result = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table (empty)
        let right = create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // No rows
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        // Left table with duplicate ids
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("admin".to_string())],
                vec![SqlValue::Integer(1), SqlValue::Varchar("user".to_string())],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table with duplicate user_ids
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // Cartesian product of matching keys: 2 left rows * 2 right rows = 4 results
        assert_eq!(results.len(), 4);

        // All should have id=1
        for row in &results {
            assert_eq!(row.values[0], SqlValue::Integer(1));
        }
    }

    #[test]
    fn test_hash_join_iterator_preserves_probe_order() {
        // Left rows: matching keys interleaved with a key (9) absent from the right side,
        // and key 1 repeated so its matches must be replayed.
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(9)],
                vec![SqlValue::Integer(2)],
                vec![SqlValue::Integer(9)],
                vec![SqlValue::Integer(1)],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();
        let results = join_iter.collect::<Result<Vec<_>, _>>().unwrap();

        // Output follows left order; matches for one key follow right insertion order.
        // Combined layout is (id, user_id, amount).
        let pairs: Vec<(SqlValue, SqlValue)> =
            results.iter().map(|r| (r.values[0].clone(), r.values[2].clone())).collect();
        assert_eq!(
            pairs,
            vec![
                (SqlValue::Integer(1), SqlValue::Integer(100)),
                (SqlValue::Integer(1), SqlValue::Integer(150)),
                (SqlValue::Integer(2), SqlValue::Integer(200)),
                (SqlValue::Integer(1), SqlValue::Integer(100)),
                (SqlValue::Integer(1), SqlValue::Integer(150)),
            ]
        );
    }

    #[test]
    fn test_hash_table_size_excludes_null_keys() {
        let left_result = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Null],
                vec![SqlValue::Integer(2)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        // Three non-NULL keyed rows are kept; the NULL-key row is dropped at build time
        assert_eq!(join_iter.hash_table_size(), 3);
    }

    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        // This test verifies that the left side is truly lazy
        // We'll create an iterator that tracks how many rows have been consumed

        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            index: usize,
            consumed_count: std::sync::Arc<std::sync::Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                if self.index < self.rows.len() {
                    let row = self.rows[self.index].clone();
                    self.index += 1;
                    *self.consumed_count.lock().unwrap() += 1;
                    Some(Ok(row))
                } else {
                    None
                }
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let consumed = std::sync::Arc::new(std::sync::Mutex::new(0));

        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(2)],
                vec![SqlValue::Integer(3)],
                vec![SqlValue::Integer(4)],
                vec![SqlValue::Integer(5)],
            ],
        );

        let counting_iter = CountingIterator {
            schema: left_result.schema.clone(),
            rows: left_result.into_rows(),
            index: 0,
            consumed_count: consumed.clone(),
        };

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let join_iter = HashJoinIterator::new(counting_iter, right, 0, 0).unwrap();

        // Take only first 2 results
        let results: Vec<_> = join_iter.take(2).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(results.len(), 2);

        // Left ids 1 and 2 each produce exactly one match, and `take(2)` stops polling after
        // the second result, so exactly two left rows are consumed (rows 3..=5 never are).
        let consumed_count = *consumed.lock().unwrap();
        assert_eq!(
            consumed_count, 2,
            "Expected exactly 2 left rows consumed, got {}",
            consumed_count
        );
    }
}