vibesql_executor/select/join/
hash_join_iterator.rs

//! Hash join iterator implementation for lazy evaluation
//!
//! This module implements an iterator-based hash join that provides O(N+M)
//! performance while maintaining lazy evaluation for the left (probe) side.
//!
//! # Bloom Filter Optimization
//!
//! This implementation uses a Bloom filter to quickly reject probe-side rows
//! that cannot possibly match any build-side rows. This optimization is
//! inspired by SQLite's `WHERE_BLOOMFILTER` optimization.
//!
//! Benefits:
//! - Bloom filter check is O(1) with excellent cache behavior
//! - For selective joins, eliminates most hash table lookups
//! - Memory overhead is small (~10 bits per build-side element)
16
17#![allow(clippy::manual_is_multiple_of)]
18
19use std::hash::{Hash, Hasher};
20
21use ahash::{AHashMap, AHasher};
22
23use super::{combine_rows, BloomFilter, FromResult};
24use crate::{
25    errors::ExecutorError,
26    schema::CombinedSchema,
27    select::RowIterator,
28    timeout::{TimeoutContext, CHECK_INTERVAL},
29};
30
/// Build-side row count at or above which a Bloom filter is constructed.
///
/// Below this threshold the join relies on the hash table alone, on the
/// expectation that the filter's upkeep would cost more than it saves.
const BLOOM_FILTER_MIN_ROWS: usize = 100;
34
/// False-positive rate requested from the Bloom filter (0.01, i.e. 1%).
///
/// A smaller value filters more precisely at the price of a larger filter.
const BLOOM_FILTER_FPR: f64 = 0.01;
38
39/// Hash a SqlValue for use in Bloom filter.
40/// Uses AHash for fast, high-quality hashing.
41#[inline]
42fn hash_sql_value(value: &vibesql_types::SqlValue) -> u64 {
43    let mut hasher = AHasher::default();
44    value.hash(&mut hasher);
45    hasher.finish()
46}
47
48/// Hash join iterator that lazily produces joined rows
49///
50/// This implementation uses a hash join algorithm with:
51/// - Lazy left (probe) side: rows consumed on-demand from iterator
52/// - Materialized right (build) side: all rows hashed into HashMap
53/// - Bloom filter for quick rejection of non-matching probe rows
54///
55/// Algorithm:
56/// 1. Build phase: Materialize right side into hash table AND Bloom filter (one-time cost)
57/// 2. Probe phase: For each left row:
58///    a. Check Bloom filter - if negative, skip immediately (no hash lookup)
59///    b. If Bloom filter positive, do actual hash table lookup
60///
61/// Performance: O(N + M) where N=left rows, M=right rows
62/// For selective joins, Bloom filter reduces constant factor significantly.
63///
64/// Memory: O(M) for right side hash table + O(M * 10 bits) for Bloom filter
65pub struct HashJoinIterator<L: RowIterator> {
66    /// Lazy probe side (left)
67    left: L,
68    /// Materialized build side (right) - hash table mapping join key to rows
69    right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
70    /// Bloom filter for quick rejection of non-matching keys (None if disabled)
71    bloom_filter: Option<BloomFilter>,
72    /// Combined schema for output rows
73    schema: CombinedSchema,
74    /// Column index in left table for join key
75    left_col_idx: usize,
76    /// Column index in right table for join key
77    #[allow(dead_code)]
78    right_col_idx: usize,
79    /// Current left row being processed
80    current_left_row: Option<vibesql_storage::Row>,
81    /// Matching right rows for current left row
82    current_matches: Vec<vibesql_storage::Row>,
83    /// Index into current_matches
84    match_index: usize,
85    /// Number of right columns (for NULL padding)
86    #[allow(dead_code)]
87    right_col_count: usize,
88    /// Timeout context for query timeout enforcement
89    timeout_ctx: TimeoutContext,
90    /// Iteration counter for periodic timeout checks
91    iteration_count: usize,
92    /// Count of rows rejected by Bloom filter (for debugging/profiling)
93    #[allow(dead_code)]
94    bloom_rejections: usize,
95}
96
97impl<L: RowIterator> HashJoinIterator<L> {
98    /// Create a new hash join iterator for INNER JOIN
99    ///
100    /// # Arguments
101    /// * `left` - Lazy iterator for left (probe) side
102    /// * `right` - Materialized right (build) side
103    /// * `left_col_idx` - Column index in left table for join key
104    /// * `right_col_idx` - Column index in right table for join key
105    ///
106    /// # Returns
107    /// * `Ok(HashJoinIterator)` - Successfully created iterator
108    /// * `Err(ExecutorError)` - Failed due to memory limits or schema issues
109    #[allow(private_interfaces)]
110    pub fn new(
111        left: L,
112        right: FromResult,
113        left_col_idx: usize,
114        right_col_idx: usize,
115    ) -> Result<Self, ExecutorError> {
116        // Extract right table schema
117        let right_table_name = right
118            .schema
119            .table_schemas
120            .keys()
121            .next()
122            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
123            .clone();
124
125        let right_schema = right
126            .schema
127            .table_schemas
128            .get(&right_table_name)
129            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
130            .1
131            .clone();
132
133        let right_col_count = right_schema.columns.len();
134
135        // Combine schemas (left schema from iterator + right schema)
136        let combined_schema =
137            CombinedSchema::combine(left.schema().clone(), right_table_name, right_schema);
138
139        // Use default timeout context (proper propagation from SelectExecutor is a future improvement)
140        let timeout_ctx = TimeoutContext::new_default();
141
142        // Build phase: Create hash table from right side
143        // This is the one-time materialization cost
144        let right_rows = right.into_rows();
145        let num_build_rows = right_rows.len();
146
147        let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> =
148            AHashMap::new();
149
150        // Create Bloom filter if we have enough rows to benefit
151        // Check environment variable for disabling (useful for A/B testing)
152        let bloom_disabled = std::env::var("VIBESQL_DISABLE_BLOOM_FILTER").is_ok();
153        let mut bloom_filter = if !bloom_disabled && num_build_rows >= BLOOM_FILTER_MIN_ROWS {
154            Some(BloomFilter::new(num_build_rows, BLOOM_FILTER_FPR))
155        } else {
156            None
157        };
158
159        let mut build_iterations = 0;
160
161        for row in right_rows {
162            // Check timeout periodically during build phase
163            build_iterations += 1;
164            if build_iterations % CHECK_INTERVAL == 0 {
165                timeout_ctx.check()?;
166            }
167
168            let key = row.values[right_col_idx].clone();
169
170            // Skip NULL values - they never match in equi-joins
171            if key != vibesql_types::SqlValue::Null {
172                // Insert into Bloom filter before hash table
173                if let Some(ref mut bf) = bloom_filter {
174                    // Hash the SqlValue and insert into Bloom filter
175                    let hash = hash_sql_value(&key);
176                    bf.insert_hash(hash);
177                }
178
179                hash_table.entry(key).or_default().push(row);
180            }
181        }
182
183        Ok(Self {
184            left,
185            right_hash_table: hash_table,
186            bloom_filter,
187            schema: combined_schema,
188            left_col_idx,
189            right_col_idx,
190            current_left_row: None,
191            current_matches: Vec::new(),
192            match_index: 0,
193            right_col_count,
194            timeout_ctx,
195            iteration_count: 0,
196            bloom_rejections: 0,
197        })
198    }
199
200    /// Get the number of rows in the hash table (right side)
201    pub fn hash_table_size(&self) -> usize {
202        self.right_hash_table.values().map(|v| v.len()).sum()
203    }
204}
205
206impl<L: RowIterator> Iterator for HashJoinIterator<L> {
207    type Item = Result<vibesql_storage::Row, ExecutorError>;
208
209    fn next(&mut self) -> Option<Self::Item> {
210        loop {
211            // Check timeout periodically during probe phase
212            self.iteration_count += 1;
213            if self.iteration_count % CHECK_INTERVAL == 0 {
214                if let Err(e) = self.timeout_ctx.check() {
215                    return Some(Err(e));
216                }
217            }
218
219            // If we have remaining matches for current left row, return next match
220            if self.match_index < self.current_matches.len() {
221                let right_row = &self.current_matches[self.match_index];
222                self.match_index += 1;
223
224                // Combine left and right rows
225                if let Some(ref left_row) = self.current_left_row {
226                    let combined_row = combine_rows(left_row, right_row);
227                    return Some(Ok(combined_row));
228                }
229            }
230
231            // No more matches for current left row, get next left row
232            match self.left.next() {
233                Some(Ok(left_row)) => {
234                    let key = &left_row.values[self.left_col_idx];
235
236                    // Skip NULL values - they never match in equi-joins
237                    if key == &vibesql_types::SqlValue::Null {
238                        // For INNER JOIN, skip rows with NULL join keys
239                        continue;
240                    }
241
242                    // BLOOM FILTER OPTIMIZATION:
243                    // Quick check to see if this key MIGHT be in the build side.
244                    // If the Bloom filter says "definitely not present", skip the
245                    // expensive hash table lookup entirely.
246                    if let Some(ref bf) = self.bloom_filter {
247                        let hash = hash_sql_value(key);
248                        if !bf.might_contain_hash(hash) {
249                            // Bloom filter says definitely no match - skip this row
250                            self.bloom_rejections += 1;
251                            continue;
252                        }
253                    }
254
255                    // Lookup matches in hash table
256                    if let Some(matches) = self.right_hash_table.get(key) {
257                        // Found matches - set up for iteration
258                        self.current_left_row = Some(left_row);
259                        self.current_matches = matches.clone();
260                        self.match_index = 0;
261                        // Continue loop to return first match
262                    } else {
263                        // No matches for this left row
264                        // For INNER JOIN, skip this row
265                        // (This can happen due to Bloom filter false positives)
266                        continue;
267                    }
268                }
269                Some(Err(e)) => {
270                    // Propagate error from left iterator
271                    return Some(Err(e));
272                }
273                None => {
274                    // Left iterator exhausted, we're done
275                    return None;
276                }
277            }
278        }
279    }
280}
281
282impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
283    fn schema(&self) -> &CombinedSchema {
284        &self.schema
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::select::TableScanIterator;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    /// Helper to create a simple single-table `FromResult` for testing.
    /// Every column is declared nullable.
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let schema = TableSchema::new(
            table_name.to_string(),
            columns
                .iter()
                .map(|(name, dtype)| {
                    ColumnSchema::new(
                        name.to_string(),
                        dtype.clone(),
                        true, // nullable
                    )
                })
                .collect(),
        );

        let combined_schema = CombinedSchema::from_table(table_name.to_string(), schema);
        let rows = rows.into_iter().map(Row::new).collect();

        FromResult::from_rows(combined_schema, rows)
    }

    /// Shorthand for a `SqlValue::Varchar` literal.
    fn varchar(text: &str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(text))
    }

    /// Joins `left` with `right` on column 0 of both sides and collects every
    /// output row, panicking on any error.
    fn join_on_first_column(left: FromResult, right: FromResult) -> Vec<Row> {
        let left_iter = TableScanIterator::new(left.schema.clone(), left.into_rows());
        HashJoinIterator::new(left_iter, right, 0, 0)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    /// Counts the rows whose first column equals `value`.
    fn rows_with_first_value(rows: &[Row], value: &SqlValue) -> usize {
        rows.iter().filter(|row| row.values[0] == *value).count()
    }

    #[test]
    fn test_hash_join_iterator_simple() {
        // Left table: users(id, name)
        let left = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), varchar("Alice")],
                vec![SqlValue::Integer(2), varchar("Bob")],
                vec![SqlValue::Integer(3), varchar("Charlie")],
            ],
        );

        // Right table: orders(user_id, amount)
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        // Join on users.id = orders.user_id (column 0 from both sides)
        let results = join_on_first_column(left, right);

        // 3 rows: user 1 has 2 orders, user 2 has 1 order, user 3 has none
        assert_eq!(results.len(), 3);

        // Combined rows have 4 columns: id, name, user_id, amount
        for row in &results {
            assert_eq!(row.values.len(), 4);
        }

        // Alice (id=1) appears twice, Bob (id=2) once, Charlie (id=3) never
        assert_eq!(rows_with_first_value(&results, &SqlValue::Integer(1)), 2);
        assert_eq!(rows_with_first_value(&results, &SqlValue::Integer(2)), 1);
        assert_eq!(rows_with_first_value(&results, &SqlValue::Integer(3)), 0);
    }

    #[test]
    fn test_hash_join_iterator_null_values() {
        // Left table with NULL id
        let left = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), varchar("Alice")],
                vec![SqlValue::Null, varchar("Unknown")],
            ],
        );

        // Right table with NULL user_id
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let results = join_on_first_column(left, right);

        // Only one match: Alice (id=1) with order (user_id=1).
        // NULLs must not match each other in equi-joins.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].values[0], SqlValue::Integer(1)); // user id
        assert_eq!(results[0].values[1], varchar("Alice")); // user name
        assert_eq!(results[0].values[2], SqlValue::Integer(1)); // order user_id
        assert_eq!(results[0].values[3], SqlValue::Integer(100)); // order amount
    }

    #[test]
    fn test_hash_join_iterator_no_matches() {
        let left = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        // Right table with non-matching ids
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        let results = join_on_first_column(left, right);

        assert!(results.is_empty());
    }

    #[test]
    fn test_hash_join_iterator_empty_tables() {
        let left = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);
        let right = create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        let results = join_on_first_column(left, right);

        assert!(results.is_empty());
    }

    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        // Left table with duplicate ids
        let left = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), varchar("admin")],
                vec![SqlValue::Integer(1), varchar("user")],
            ],
        );

        // Right table with duplicate user_ids
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let results = join_on_first_column(left, right);

        // Cartesian product of matching keys: 2 left rows * 2 right rows = 4 results
        assert_eq!(results.len(), 4);

        // All should have id=1
        for row in &results {
            assert_eq!(row.values[0], SqlValue::Integer(1));
        }
    }

    #[test]
    fn test_hash_join_iterator_large_build_side() {
        // The build side is large enough to cross BLOOM_FILTER_MIN_ROWS, so the
        // Bloom filter path is exercised (unless disabled via the environment,
        // in which case the expected output is the same).
        // Left ids: 0..400; right user_ids: the 200 even numbers in 0..400.
        let left = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            (0..400).map(|id| vec![SqlValue::Integer(id)]).collect(),
        );
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            (0..200).map(|i| vec![SqlValue::Integer(2 * i)]).collect(),
        );

        let left_iter = TableScanIterator::new(left.schema.clone(), left.into_rows());
        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        // Guard: the build side must actually reach the Bloom filter threshold.
        assert!(join_iter.hash_table_size() >= BLOOM_FILTER_MIN_ROWS);
        assert_eq!(join_iter.hash_table_size(), 200);

        let results = join_iter.collect::<Result<Vec<_>, _>>().unwrap();

        // Every even left id matches exactly one right row; odd ids match none.
        assert_eq!(results.len(), 200);
        for row in &results {
            assert_eq!(row.values.len(), 2);
            assert_eq!(row.values[0], row.values[1]);
        }
    }

    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        // Verifies that the left side is truly lazy by wrapping it in an
        // iterator that counts how many rows have been pulled from it.

        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            index: usize,
            consumed_count: std::sync::Arc<std::sync::Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                let row = self.rows.get(self.index)?.clone();
                self.index += 1;
                *self.consumed_count.lock().unwrap() += 1;
                Some(Ok(row))
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let consumed = std::sync::Arc::new(std::sync::Mutex::new(0));

        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(2)],
                vec![SqlValue::Integer(3)],
                vec![SqlValue::Integer(4)],
                vec![SqlValue::Integer(5)],
            ],
        );

        let counting_iter = CountingIterator {
            schema: left_result.schema.clone(),
            rows: left_result.into_rows(),
            index: 0,
            consumed_count: consumed.clone(),
        };

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let join_iter = HashJoinIterator::new(counting_iter, right, 0, 0).unwrap();

        // Take only the first 2 results
        let results: Vec<_> = join_iter.take(2).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(results.len(), 2);

        // Left ids 1 and 2 each produce one output row, so exactly two probe
        // rows must have been pulled; rows 3..5 must remain untouched.
        let consumed_count = *consumed.lock().unwrap();
        assert_eq!(
            consumed_count, 2,
            "Expected exactly 2 rows consumed, got {}",
            consumed_count
        );
    }
}