// vibesql_executor/select/join/hash_join_iterator.rs

1//! Hash join iterator implementation for lazy evaluation
2//!
3//! This module implements an iterator-based hash join that provides O(N+M)
4//! performance while maintaining lazy evaluation for the left (probe) side.
5//!
6//! # Bloom Filter Optimization
7//!
8//! This implementation uses a Bloom filter to quickly reject probe-side rows
9//! that cannot possibly match any build-side rows. This optimization is
10//! inspired by SQLite's `WHERE_BLOOMFILTER` optimization.
11//!
12//! Benefits:
13//! - Bloom filter check is O(1) with excellent cache behavior
14//! - For selective joins, eliminates most hash table lookups
15//! - Memory overhead is small (~10 bits per build-side element)
16
17#![allow(clippy::manual_is_multiple_of)]
18
19use std::hash::{Hash, Hasher};
20
21use ahash::{AHashMap, AHasher};
22
23use super::{combine_rows, BloomFilter, FromResult};
24use crate::{
25    errors::ExecutorError,
26    schema::CombinedSchema,
27    select::RowIterator,
28    timeout::{TimeoutContext, CHECK_INTERVAL},
29};
30
/// Smallest build side (in rows) for which a Bloom filter is constructed.
///
/// Below this size the cost of building and probing the filter is unlikely to
/// pay for itself, so small joins go straight to the hash table.
const BLOOM_FILTER_MIN_ROWS: usize = 100;

/// Desired false-positive probability handed to the Bloom filter (1%).
///
/// A smaller value rejects more non-matching probe rows but needs more memory.
const BLOOM_FILTER_FPR: f64 = 1e-2;
38
39/// Hash a SqlValue for use in Bloom filter.
40/// Uses AHash for fast, high-quality hashing.
41#[inline]
42fn hash_sql_value(value: &vibesql_types::SqlValue) -> u64 {
43    let mut hasher = AHasher::default();
44    value.hash(&mut hasher);
45    hasher.finish()
46}
47
48/// Hash join iterator that lazily produces joined rows
49///
50/// This implementation uses a hash join algorithm with:
51/// - Lazy left (probe) side: rows consumed on-demand from iterator
52/// - Materialized right (build) side: all rows hashed into HashMap
53/// - Bloom filter for quick rejection of non-matching probe rows
54///
55/// Algorithm:
56/// 1. Build phase: Materialize right side into hash table AND Bloom filter (one-time cost)
57/// 2. Probe phase: For each left row: a. Check Bloom filter - if negative, skip immediately (no
58///    hash lookup) b. If Bloom filter positive, do actual hash table lookup
59///
60/// Performance: O(N + M) where N=left rows, M=right rows
61/// For selective joins, Bloom filter reduces constant factor significantly.
62///
63/// Memory: O(M) for right side hash table + O(M * 10 bits) for Bloom filter
64pub struct HashJoinIterator<L: RowIterator> {
65    /// Lazy probe side (left)
66    left: L,
67    /// Materialized build side (right) - hash table mapping join key to rows
68    right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
69    /// Bloom filter for quick rejection of non-matching keys (None if disabled)
70    bloom_filter: Option<BloomFilter>,
71    /// Combined schema for output rows
72    schema: CombinedSchema,
73    /// Column index in left table for join key
74    left_col_idx: usize,
75    /// Column index in right table for join key
76    #[allow(dead_code)]
77    right_col_idx: usize,
78    /// Current left row being processed
79    current_left_row: Option<vibesql_storage::Row>,
80    /// Matching right rows for current left row
81    current_matches: Vec<vibesql_storage::Row>,
82    /// Index into current_matches
83    match_index: usize,
84    /// Number of right columns (for NULL padding)
85    #[allow(dead_code)]
86    right_col_count: usize,
87    /// Timeout context for query timeout enforcement
88    timeout_ctx: TimeoutContext,
89    /// Iteration counter for periodic timeout checks
90    iteration_count: usize,
91    /// Count of rows rejected by Bloom filter (for debugging/profiling)
92    #[allow(dead_code)]
93    bloom_rejections: usize,
94}
95
96impl<L: RowIterator> HashJoinIterator<L> {
97    /// Create a new hash join iterator for INNER JOIN
98    ///
99    /// # Arguments
100    /// * `left` - Lazy iterator for left (probe) side
101    /// * `right` - Materialized right (build) side
102    /// * `left_col_idx` - Column index in left table for join key
103    /// * `right_col_idx` - Column index in right table for join key
104    ///
105    /// # Returns
106    /// * `Ok(HashJoinIterator)` - Successfully created iterator
107    /// * `Err(ExecutorError)` - Failed due to memory limits or schema issues
108    #[allow(private_interfaces)]
109    pub fn new(
110        left: L,
111        right: FromResult,
112        left_col_idx: usize,
113        right_col_idx: usize,
114    ) -> Result<Self, ExecutorError> {
115        // Extract right table schema
116        let right_table_name = right
117            .schema
118            .table_schemas
119            .keys()
120            .next()
121            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
122            .clone();
123
124        let right_schema = right
125            .schema
126            .table_schemas
127            .get(&right_table_name)
128            .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
129            .1
130            .clone();
131
132        let right_col_count = right_schema.columns.len();
133
134        // Get display name from the TableIdentifier
135        let right_table_display_name = right_table_name.display().to_string();
136
137        // Combine schemas (left schema from iterator + right schema)
138        let combined_schema = CombinedSchema::combine(
139            left.schema().clone(),
140            right_table_display_name.clone(),
141            right_schema,
142        );
143
144        // Use default timeout context (proper propagation from SelectExecutor is a future
145        // improvement)
146        let timeout_ctx = TimeoutContext::new_default();
147
148        // Build phase: Create hash table from right side
149        // This is the one-time materialization cost
150        let right_rows = right.into_rows();
151        let num_build_rows = right_rows.len();
152
153        let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> =
154            AHashMap::new();
155
156        // Create Bloom filter if we have enough rows to benefit
157        // Check environment variable for disabling (useful for A/B testing)
158        let bloom_disabled = std::env::var("VIBESQL_DISABLE_BLOOM_FILTER").is_ok();
159        let mut bloom_filter = if !bloom_disabled && num_build_rows >= BLOOM_FILTER_MIN_ROWS {
160            Some(BloomFilter::new(num_build_rows, BLOOM_FILTER_FPR))
161        } else {
162            None
163        };
164
165        let mut build_iterations = 0;
166
167        for row in right_rows {
168            // Check timeout periodically during build phase
169            build_iterations += 1;
170            if build_iterations % CHECK_INTERVAL == 0 {
171                timeout_ctx.check()?;
172            }
173
174            let key = row.values[right_col_idx].clone();
175
176            // Skip NULL values - they never match in equi-joins
177            if key != vibesql_types::SqlValue::Null {
178                // Insert into Bloom filter before hash table
179                if let Some(ref mut bf) = bloom_filter {
180                    // Hash the SqlValue and insert into Bloom filter
181                    let hash = hash_sql_value(&key);
182                    bf.insert_hash(hash);
183                }
184
185                hash_table.entry(key).or_default().push(row);
186            }
187        }
188
189        Ok(Self {
190            left,
191            right_hash_table: hash_table,
192            bloom_filter,
193            schema: combined_schema,
194            left_col_idx,
195            right_col_idx,
196            current_left_row: None,
197            current_matches: Vec::new(),
198            match_index: 0,
199            right_col_count,
200            timeout_ctx,
201            iteration_count: 0,
202            bloom_rejections: 0,
203        })
204    }
205
206    /// Get the number of rows in the hash table (right side)
207    pub fn hash_table_size(&self) -> usize {
208        self.right_hash_table.values().map(|v| v.len()).sum()
209    }
210}
211
212impl<L: RowIterator> Iterator for HashJoinIterator<L> {
213    type Item = Result<vibesql_storage::Row, ExecutorError>;
214
215    fn next(&mut self) -> Option<Self::Item> {
216        loop {
217            // Check timeout periodically during probe phase
218            self.iteration_count += 1;
219            if self.iteration_count % CHECK_INTERVAL == 0 {
220                if let Err(e) = self.timeout_ctx.check() {
221                    return Some(Err(e));
222                }
223            }
224
225            // If we have remaining matches for current left row, return next match
226            if self.match_index < self.current_matches.len() {
227                let right_row = &self.current_matches[self.match_index];
228                self.match_index += 1;
229
230                // Combine left and right rows
231                if let Some(ref left_row) = self.current_left_row {
232                    let combined_row = combine_rows(left_row, right_row);
233                    return Some(Ok(combined_row));
234                }
235            }
236
237            // No more matches for current left row, get next left row
238            match self.left.next() {
239                Some(Ok(left_row)) => {
240                    let key = &left_row.values[self.left_col_idx];
241
242                    // Skip NULL values - they never match in equi-joins
243                    if key == &vibesql_types::SqlValue::Null {
244                        // For INNER JOIN, skip rows with NULL join keys
245                        continue;
246                    }
247
248                    // BLOOM FILTER OPTIMIZATION:
249                    // Quick check to see if this key MIGHT be in the build side.
250                    // If the Bloom filter says "definitely not present", skip the
251                    // expensive hash table lookup entirely.
252                    if let Some(ref bf) = self.bloom_filter {
253                        let hash = hash_sql_value(key);
254                        if !bf.might_contain_hash(hash) {
255                            // Bloom filter says definitely no match - skip this row
256                            self.bloom_rejections += 1;
257                            continue;
258                        }
259                    }
260
261                    // Lookup matches in hash table
262                    if let Some(matches) = self.right_hash_table.get(key) {
263                        // Found matches - set up for iteration
264                        self.current_left_row = Some(left_row);
265                        self.current_matches = matches.clone();
266                        self.match_index = 0;
267                        // Continue loop to return first match
268                    } else {
269                        // No matches for this left row
270                        // For INNER JOIN, skip this row
271                        // (This can happen due to Bloom filter false positives)
272                        continue;
273                    }
274                }
275                Some(Err(e)) => {
276                    // Propagate error from left iterator
277                    return Some(Err(e));
278                }
279                None => {
280                    // Left iterator exhausted, we're done
281                    return None;
282                }
283            }
284        }
285    }
286}
287
288impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
289    fn schema(&self) -> &CombinedSchema {
290        &self.schema
291    }
292}
293
#[cfg(test)]
mod tests {
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    use super::*;
    use crate::select::TableScanIterator;

    /// Helper to create a simple FromResult for testing
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let schema = TableSchema::new(
            table_name.to_string(),
            columns
                .iter()
                .map(|(name, dtype)| {
                    ColumnSchema::new(
                        name.to_string(),
                        dtype.clone(),
                        true, // nullable
                    )
                })
                .collect(),
        );

        let combined_schema = CombinedSchema::from_table(table_name.to_string(), schema);
        let rows = rows.into_iter().map(Row::new).collect();

        FromResult::from_rows(combined_schema, rows)
    }

    #[test]
    fn test_hash_join_iterator_simple() {
        // Left table: users(id, name)
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("Alice"))],
                vec![SqlValue::Integer(2), SqlValue::Varchar(arcstr::ArcStr::from("Bob"))],
                vec![SqlValue::Integer(3), SqlValue::Varchar(arcstr::ArcStr::from("Charlie"))],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table: orders(user_id, amount)
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        // Join on users.id = orders.user_id (column 0 from both sides)
        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        // Collect results
        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // Should have 3 rows (user 1 has 2 orders, user 2 has 1 order, user 3 has no orders)
        assert_eq!(results.len(), 3);

        // Verify combined rows have correct structure (4 columns: id, name, user_id, amount)
        for row in &results {
            assert_eq!(row.values.len(), 4);
        }

        // Check specific matches
        // Alice (id=1) should appear twice (2 orders)
        let alice_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(1)).collect();
        assert_eq!(alice_orders.len(), 2);

        // Bob (id=2) should appear once (1 order)
        let bob_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(2)).collect();
        assert_eq!(bob_orders.len(), 1);

        // Charlie (id=3) should not appear (no orders)
        let charlie_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(3)).collect();
        assert_eq!(charlie_orders.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_null_values() {
        // Left table with NULL id
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("Alice"))],
                vec![SqlValue::Null, SqlValue::Varchar(arcstr::ArcStr::from("Unknown"))],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table with NULL user_id
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // Only one match: Alice (id=1) with order (user_id=1)
        // NULLs should not match each other in equi-joins
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].values[0], SqlValue::Integer(1)); // user id
        assert_eq!(results[0].values[1], SqlValue::Varchar(arcstr::ArcStr::from("Alice"))); // user name
        assert_eq!(results[0].values[2], SqlValue::Integer(1)); // order user_id
        assert_eq!(results[0].values[3], SqlValue::Integer(100)); // order amount
    }

    #[test]
    fn test_hash_join_iterator_no_matches() {
        // Left table
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table with non-matching ids
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // No matches
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_empty_tables() {
        // Left table (empty)
        let left_result = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table (empty)
        let right = create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // No rows
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        // Left table with duplicate ids
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("admin"))],
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("user"))],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        // Right table with duplicate user_ids
        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // Cartesian product of matching keys: 2 left rows * 2 right rows = 4 results
        assert_eq!(results.len(), 4);

        // All should have id=1
        for row in &results {
            assert_eq!(row.values[0], SqlValue::Integer(1));
        }
    }

    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        // This test verifies that the left side is truly lazy
        // We'll create an iterator that tracks how many rows have been consumed

        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            index: usize,
            consumed_count: std::sync::Arc<std::sync::Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                if self.index < self.rows.len() {
                    let row = self.rows[self.index].clone();
                    self.index += 1;
                    *self.consumed_count.lock().unwrap() += 1;
                    Some(Ok(row))
                } else {
                    None
                }
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let consumed = std::sync::Arc::new(std::sync::Mutex::new(0));

        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(2)],
                vec![SqlValue::Integer(3)],
                vec![SqlValue::Integer(4)],
                vec![SqlValue::Integer(5)],
            ],
        );

        let counting_iter = CountingIterator {
            schema: left_result.schema.clone(),
            rows: left_result.into_rows(),
            index: 0,
            consumed_count: consumed.clone(),
        };

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let join_iter = HashJoinIterator::new(counting_iter, right, 0, 0).unwrap();

        // Take only first 2 results
        let results: Vec<_> = join_iter.take(2).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(results.len(), 2);

        // Verify that we didn't consume all left rows (lazy evaluation).
        // Left rows 1 and 2 each match exactly one right row, so producing two results
        // needs exactly two left rows. Any read-ahead would show up as a larger count.
        let consumed_count = *consumed.lock().unwrap();
        assert_eq!(consumed_count, 2, "Expected exactly 2 rows consumed, got {}", consumed_count);
    }

    #[test]
    fn test_hash_join_iterator_bloom_filter_path() {
        // Right side large enough to enable the Bloom filter: user_id keys 0..200.
        let right_rows: Vec<Vec<SqlValue>> =
            (0..200).map(|k| vec![SqlValue::Integer(k)]).collect();
        assert!(right_rows.len() >= BLOOM_FILTER_MIN_ROWS);
        let right =
            create_test_from_result("orders", vec![("user_id", DataType::Integer)], right_rows);

        // Left side has three kinds of keys:
        // - 100 matching keys (the even numbers below 200),
        // - 100 keys that are absent from the right side,
        // - one NULL key.
        let left_rows: Vec<Vec<SqlValue>> = (0..200)
            .step_by(2)
            .chain(1000..1100)
            .map(|k| vec![SqlValue::Integer(k)])
            .chain(std::iter::once(vec![SqlValue::Null]))
            .collect();
        let left_result =
            create_test_from_result("users", vec![("id", DataType::Integer)], left_rows);
        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        // The filter is built unless it is explicitly disabled through the environment.
        if std::env::var("VIBESQL_DISABLE_BLOOM_FILTER").is_err() {
            assert!(
                join_iter.bloom_filter.is_some(),
                "expected a Bloom filter for a large build side"
            );
        }

        let results: Vec<_> = join_iter.collect::<Result<Vec<_>, _>>().unwrap();

        // A Bloom filter may give false positives but never false negatives. So all 100
        // matching keys must survive, while the absent keys and the NULL key produce nothing.
        assert_eq!(results.len(), 100);
        for row in &results {
            assert_eq!(row.values.len(), 2);
            assert_eq!(row.values[0], row.values[1]);
        }
    }
}