1#![allow(clippy::manual_is_multiple_of)]
18
19use std::hash::{Hash, Hasher};
20
21use ahash::{AHashMap, AHasher};
22
23use super::{combine_rows, BloomFilter, FromResult};
24use crate::{
25 errors::ExecutorError,
26 schema::CombinedSchema,
27 select::RowIterator,
28 timeout::{TimeoutContext, CHECK_INTERVAL},
29};
30
/// Smallest build-side row count for which a Bloom filter is constructed;
/// smaller build inputs skip the filter and probe the hash table directly.
const BLOOM_FILTER_MIN_ROWS: usize = 100;

/// Target false-positive rate handed to `BloomFilter::new` when a filter is built.
const BLOOM_FILTER_FPR: f64 = 1e-2;
38
39#[inline]
42fn hash_sql_value(value: &vibesql_types::SqlValue) -> u64 {
43 let mut hasher = AHasher::default();
44 value.hash(&mut hasher);
45 hasher.finish()
46}
47
48pub struct HashJoinIterator<L: RowIterator> {
65 left: L,
67 right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
69 bloom_filter: Option<BloomFilter>,
71 schema: CombinedSchema,
73 left_col_idx: usize,
75 #[allow(dead_code)]
77 right_col_idx: usize,
78 current_left_row: Option<vibesql_storage::Row>,
80 current_matches: Vec<vibesql_storage::Row>,
82 match_index: usize,
84 #[allow(dead_code)]
86 right_col_count: usize,
87 timeout_ctx: TimeoutContext,
89 iteration_count: usize,
91 #[allow(dead_code)]
93 bloom_rejections: usize,
94}
95
96impl<L: RowIterator> HashJoinIterator<L> {
97 #[allow(private_interfaces)]
109 pub fn new(
110 left: L,
111 right: FromResult,
112 left_col_idx: usize,
113 right_col_idx: usize,
114 ) -> Result<Self, ExecutorError> {
115 let right_table_name = right
117 .schema
118 .table_schemas
119 .keys()
120 .next()
121 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
122 .clone();
123
124 let right_schema = right
125 .schema
126 .table_schemas
127 .get(&right_table_name)
128 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
129 .1
130 .clone();
131
132 let right_col_count = right_schema.columns.len();
133
134 let right_table_display_name = right_table_name.display().to_string();
136
137 let combined_schema = CombinedSchema::combine(
139 left.schema().clone(),
140 right_table_display_name.clone(),
141 right_schema,
142 );
143
144 let timeout_ctx = TimeoutContext::new_default();
147
148 let right_rows = right.into_rows();
151 let num_build_rows = right_rows.len();
152
153 let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> =
154 AHashMap::new();
155
156 let bloom_disabled = std::env::var("VIBESQL_DISABLE_BLOOM_FILTER").is_ok();
159 let mut bloom_filter = if !bloom_disabled && num_build_rows >= BLOOM_FILTER_MIN_ROWS {
160 Some(BloomFilter::new(num_build_rows, BLOOM_FILTER_FPR))
161 } else {
162 None
163 };
164
165 let mut build_iterations = 0;
166
167 for row in right_rows {
168 build_iterations += 1;
170 if build_iterations % CHECK_INTERVAL == 0 {
171 timeout_ctx.check()?;
172 }
173
174 let key = row.values[right_col_idx].clone();
175
176 if key != vibesql_types::SqlValue::Null {
178 if let Some(ref mut bf) = bloom_filter {
180 let hash = hash_sql_value(&key);
182 bf.insert_hash(hash);
183 }
184
185 hash_table.entry(key).or_default().push(row);
186 }
187 }
188
189 Ok(Self {
190 left,
191 right_hash_table: hash_table,
192 bloom_filter,
193 schema: combined_schema,
194 left_col_idx,
195 right_col_idx,
196 current_left_row: None,
197 current_matches: Vec::new(),
198 match_index: 0,
199 right_col_count,
200 timeout_ctx,
201 iteration_count: 0,
202 bloom_rejections: 0,
203 })
204 }
205
206 pub fn hash_table_size(&self) -> usize {
208 self.right_hash_table.values().map(|v| v.len()).sum()
209 }
210}
211
212impl<L: RowIterator> Iterator for HashJoinIterator<L> {
213 type Item = Result<vibesql_storage::Row, ExecutorError>;
214
215 fn next(&mut self) -> Option<Self::Item> {
216 loop {
217 self.iteration_count += 1;
219 if self.iteration_count % CHECK_INTERVAL == 0 {
220 if let Err(e) = self.timeout_ctx.check() {
221 return Some(Err(e));
222 }
223 }
224
225 if self.match_index < self.current_matches.len() {
227 let right_row = &self.current_matches[self.match_index];
228 self.match_index += 1;
229
230 if let Some(ref left_row) = self.current_left_row {
232 let combined_row = combine_rows(left_row, right_row);
233 return Some(Ok(combined_row));
234 }
235 }
236
237 match self.left.next() {
239 Some(Ok(left_row)) => {
240 let key = &left_row.values[self.left_col_idx];
241
242 if key == &vibesql_types::SqlValue::Null {
244 continue;
246 }
247
248 if let Some(ref bf) = self.bloom_filter {
253 let hash = hash_sql_value(key);
254 if !bf.might_contain_hash(hash) {
255 self.bloom_rejections += 1;
257 continue;
258 }
259 }
260
261 if let Some(matches) = self.right_hash_table.get(key) {
263 self.current_left_row = Some(left_row);
265 self.current_matches = matches.clone();
266 self.match_index = 0;
267 } else {
269 continue;
273 }
274 }
275 Some(Err(e)) => {
276 return Some(Err(e));
278 }
279 None => {
280 return None;
282 }
283 }
284 }
285 }
286}
287
288impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
289 fn schema(&self) -> &CombinedSchema {
290 &self.schema
291 }
292}
293
#[cfg(test)]
mod tests {
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    use super::*;
    use crate::select::TableScanIterator;

    /// Builds a single-table `FromResult` named `table_name` from column
    /// definitions and raw row values.
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let schema = TableSchema::new(
            table_name.to_string(),
            columns
                .iter()
                .map(|(name, dtype)| {
                    ColumnSchema::new(
                        name.to_string(),
                        dtype.clone(),
                        true, // presumably `nullable` — TODO confirm against ColumnSchema::new
                    )
                })
                .collect(),
        );

        let combined_schema = CombinedSchema::from_table(table_name.to_string(), schema);
        let rows = rows.into_iter().map(Row::new).collect();

        FromResult::from_rows(combined_schema, rows)
    }

    #[test]
    fn test_hash_join_iterator_simple() {
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("Alice"))],
                vec![SqlValue::Integer(2), SqlValue::Varchar(arcstr::ArcStr::from("Bob"))],
                vec![SqlValue::Integer(3), SqlValue::Varchar(arcstr::ArcStr::from("Charlie"))],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // Alice has two orders, Bob one, Charlie none.
        assert_eq!(results.len(), 3);

        // Two left columns + two right columns.
        for row in &results {
            assert_eq!(row.values.len(), 4);
        }

        let alice_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(1)).collect();
        assert_eq!(alice_orders.len(), 2);

        let bob_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(2)).collect();
        assert_eq!(bob_orders.len(), 1);

        let charlie_orders: Vec<_> =
            results.iter().filter(|r| r.values[0] == SqlValue::Integer(3)).collect();
        assert_eq!(charlie_orders.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_null_values() {
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("Alice"))],
                vec![SqlValue::Null, SqlValue::Varchar(arcstr::ArcStr::from("Unknown"))],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // NULL never equals NULL, so only the Alice row joins.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].values[0], SqlValue::Integer(1));
        assert_eq!(results[0].values[1], SqlValue::Varchar(arcstr::ArcStr::from("Alice")));
        assert_eq!(results[0].values[2], SqlValue::Integer(1));
        assert_eq!(results[0].values[3], SqlValue::Integer(100));
    }

    #[test]
    fn test_hash_join_iterator_no_matches() {
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_empty_tables() {
        let left_result = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("admin"))],
                vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("user"))],
            ],
        );

        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        let results: Result<Vec<_>, _> = join_iter.collect();
        let results = results.unwrap();

        // 2 left rows x 2 right rows with the same key = 4 pairings.
        assert_eq!(results.len(), 4);

        for row in &results {
            assert_eq!(row.values[0], SqlValue::Integer(1));
        }
    }

    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        /// Row source that records how many rows have been pulled from it.
        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            index: usize,
            consumed_count: std::sync::Arc<std::sync::Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                if self.index < self.rows.len() {
                    let row = self.rows[self.index].clone();
                    self.index += 1;
                    *self.consumed_count.lock().unwrap() += 1;
                    Some(Ok(row))
                } else {
                    None
                }
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let consumed = std::sync::Arc::new(std::sync::Mutex::new(0));

        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(2)],
                vec![SqlValue::Integer(3)],
                vec![SqlValue::Integer(4)],
                vec![SqlValue::Integer(5)],
            ],
        );

        let counting_iter = CountingIterator {
            schema: left_result.schema.clone(),
            rows: left_result.into_rows(),
            index: 0,
            consumed_count: consumed.clone(),
        };

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let join_iter = HashJoinIterator::new(counting_iter, right, 0, 0).unwrap();

        let results: Vec<_> = join_iter.take(2).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(results.len(), 2);

        // Taking two outputs must not drain the whole probe side.
        let consumed_count = *consumed.lock().unwrap();
        assert!(consumed_count <= 3, "Expected at most 3 rows consumed, got {}", consumed_count);
    }

    #[test]
    fn test_hash_join_iterator_hash_table_size_excludes_nulls() {
        let left_result = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);
        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Null],
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(2)],
            ],
        );

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();

        // Three non-NULL build rows; the NULL-keyed row is not stored.
        assert_eq!(join_iter.hash_table_size(), 3);
    }

    #[test]
    fn test_hash_join_iterator_bloom_filter_path() {
        // Even keys 0, 2, ..., 298: enough build rows to enable the Bloom filter.
        let build_rows: Vec<Vec<SqlValue>> =
            (0..150).map(|i| vec![SqlValue::Integer(i * 2), SqlValue::Integer(i)]).collect();
        assert!(build_rows.len() >= BLOOM_FILTER_MIN_ROWS);

        let right = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("seq", DataType::Integer)],
            build_rows,
        );

        // Probe keys 0..20: the ten even keys match, the ten odd keys must not.
        let left_result = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            (0..20).map(|k| vec![SqlValue::Integer(k)]).collect(),
        );
        let left_iter = TableScanIterator::new(left_result.schema.clone(), left_result.into_rows());

        let join_iter = HashJoinIterator::new(left_iter, right, 0, 0).unwrap();
        if std::env::var_os("VIBESQL_DISABLE_BLOOM_FILTER").is_none() {
            assert!(join_iter.bloom_filter.is_some());
        }

        let results = join_iter.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(results.len(), 10);
        for row in &results {
            assert_eq!(row.values.len(), 3);
            // Left key equals right key, and only even keys exist on the build side.
            assert_eq!(row.values[0], row.values[1]);
            match &row.values[0] {
                SqlValue::Integer(k) => assert_eq!(k % 2, 0),
                other => panic!("unexpected key {:?}", other),
            }
        }
    }
}