1#![allow(clippy::manual_is_multiple_of)]
18
19use std::hash::{Hash, Hasher};
20
21use ahash::{AHashMap, AHasher};
22
23use super::{combine_rows, BloomFilter, FromResult};
24use crate::{
25 errors::ExecutorError,
26 schema::CombinedSchema,
27 select::RowIterator,
28 timeout::{TimeoutContext, CHECK_INTERVAL},
29};
30
/// Minimum number of build-side rows required before `HashJoinIterator::new`
/// allocates a Bloom filter; below this threshold probing uses the hash table alone.
const BLOOM_FILTER_MIN_ROWS: usize = 100;

/// Second argument handed to `BloomFilter::new` when the filter is created.
/// NOTE(review): presumably the target false-positive rate (0.01 = 1%) — confirm
/// against `BloomFilter::new`.
const BLOOM_FILTER_FPR: f64 = 0.01;
38
39#[inline]
42fn hash_sql_value(value: &vibesql_types::SqlValue) -> u64 {
43 let mut hasher = AHasher::default();
44 value.hash(&mut hasher);
45 hasher.finish()
46}
47
/// Streaming equi-join that materializes the right input into a hash table and
/// then pulls rows from the left input one at a time.
///
/// Only left rows whose join key is present in the hash table produce output,
/// and rows whose key is `SqlValue::Null` are skipped on both sides, so `NULL`
/// never matches anything.
pub struct HashJoinIterator<L: RowIterator> {
    /// Probe-side input; advanced lazily from `next`.
    left: L,
    /// Build-side rows grouped by join-key value (rows with a `Null` key are not stored).
    right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
    /// Optional pre-filter consulted before the hash table lookup; `None` when the
    /// build side is small or the filter was disabled through the environment.
    bloom_filter: Option<BloomFilter>,
    /// Schema returned by `RowIterator::schema` for the joined rows.
    schema: CombinedSchema,
    /// Index of the join key inside each left row's `values`.
    left_col_idx: usize,
    /// Index of the join key inside each right row's `values`; only needed while
    /// building, kept on the struct but not read afterwards.
    #[allow(dead_code)]
    right_col_idx: usize,
    /// Left row whose matches are currently being emitted.
    current_left_row: Option<vibesql_storage::Row>,
    /// Copy of the build-side rows that match `current_left_row`.
    current_matches: Vec<vibesql_storage::Row>,
    /// Position of the next row in `current_matches` to emit.
    match_index: usize,
    /// Number of columns in the right table's schema; stored but not read in this file.
    #[allow(dead_code)]
    right_col_count: usize,
    /// Timeout guard checked periodically during build and probe.
    timeout_ctx: TimeoutContext,
    /// Number of probe-loop iterations so far; drives the periodic timeout check.
    iteration_count: usize,
    /// Count of left rows rejected by the Bloom filter; incremented but not read in this file.
    #[allow(dead_code)]
    bloom_rejections: usize,
}
96
97impl<L: RowIterator> HashJoinIterator<L> {
98 #[allow(private_interfaces)]
110 pub fn new(
111 left: L,
112 right: FromResult,
113 left_col_idx: usize,
114 right_col_idx: usize,
115 ) -> Result<Self, ExecutorError> {
116 let right_table_name = right
118 .schema
119 .table_schemas
120 .keys()
121 .next()
122 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
123 .clone();
124
125 let right_schema = right
126 .schema
127 .table_schemas
128 .get(&right_table_name)
129 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
130 .1
131 .clone();
132
133 let right_col_count = right_schema.columns.len();
134
135 let combined_schema =
137 CombinedSchema::combine(left.schema().clone(), right_table_name, right_schema);
138
139 let timeout_ctx = TimeoutContext::new_default();
141
142 let right_rows = right.into_rows();
145 let num_build_rows = right_rows.len();
146
147 let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> =
148 AHashMap::new();
149
150 let bloom_disabled = std::env::var("VIBESQL_DISABLE_BLOOM_FILTER").is_ok();
153 let mut bloom_filter = if !bloom_disabled && num_build_rows >= BLOOM_FILTER_MIN_ROWS {
154 Some(BloomFilter::new(num_build_rows, BLOOM_FILTER_FPR))
155 } else {
156 None
157 };
158
159 let mut build_iterations = 0;
160
161 for row in right_rows {
162 build_iterations += 1;
164 if build_iterations % CHECK_INTERVAL == 0 {
165 timeout_ctx.check()?;
166 }
167
168 let key = row.values[right_col_idx].clone();
169
170 if key != vibesql_types::SqlValue::Null {
172 if let Some(ref mut bf) = bloom_filter {
174 let hash = hash_sql_value(&key);
176 bf.insert_hash(hash);
177 }
178
179 hash_table.entry(key).or_default().push(row);
180 }
181 }
182
183 Ok(Self {
184 left,
185 right_hash_table: hash_table,
186 bloom_filter,
187 schema: combined_schema,
188 left_col_idx,
189 right_col_idx,
190 current_left_row: None,
191 current_matches: Vec::new(),
192 match_index: 0,
193 right_col_count,
194 timeout_ctx,
195 iteration_count: 0,
196 bloom_rejections: 0,
197 })
198 }
199
200 pub fn hash_table_size(&self) -> usize {
202 self.right_hash_table.values().map(|v| v.len()).sum()
203 }
204}
205
206impl<L: RowIterator> Iterator for HashJoinIterator<L> {
207 type Item = Result<vibesql_storage::Row, ExecutorError>;
208
209 fn next(&mut self) -> Option<Self::Item> {
210 loop {
211 self.iteration_count += 1;
213 if self.iteration_count % CHECK_INTERVAL == 0 {
214 if let Err(e) = self.timeout_ctx.check() {
215 return Some(Err(e));
216 }
217 }
218
219 if self.match_index < self.current_matches.len() {
221 let right_row = &self.current_matches[self.match_index];
222 self.match_index += 1;
223
224 if let Some(ref left_row) = self.current_left_row {
226 let combined_row = combine_rows(left_row, right_row);
227 return Some(Ok(combined_row));
228 }
229 }
230
231 match self.left.next() {
233 Some(Ok(left_row)) => {
234 let key = &left_row.values[self.left_col_idx];
235
236 if key == &vibesql_types::SqlValue::Null {
238 continue;
240 }
241
242 if let Some(ref bf) = self.bloom_filter {
247 let hash = hash_sql_value(key);
248 if !bf.might_contain_hash(hash) {
249 self.bloom_rejections += 1;
251 continue;
252 }
253 }
254
255 if let Some(matches) = self.right_hash_table.get(key) {
257 self.current_left_row = Some(left_row);
259 self.current_matches = matches.clone();
260 self.match_index = 0;
261 } else {
263 continue;
267 }
268 }
269 Some(Err(e)) => {
270 return Some(Err(e));
272 }
273 None => {
274 return None;
276 }
277 }
278 }
279 }
280}
281
282impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
283 fn schema(&self) -> &CombinedSchema {
284 &self.schema
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::select::TableScanIterator;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    /// Builds an in-memory `FromResult` holding a single table with the given
    /// columns and rows. Every column is created with `true` as the third
    /// `ColumnSchema::new` argument (presumably "nullable" — confirm).
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let table_schema = TableSchema::new(
            table_name.to_string(),
            columns
                .into_iter()
                .map(|(name, dtype)| ColumnSchema::new(name.to_string(), dtype, true))
                .collect(),
        );
        let combined_schema = CombinedSchema::from_table(table_name.to_string(), table_schema);

        FromResult::from_rows(combined_schema, rows.into_iter().map(Row::new).collect())
    }

    /// Shorthand for a `SqlValue::Varchar` literal.
    fn text(s: &str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(s))
    }

    /// Wraps a materialized result in a `TableScanIterator` to serve as the left input.
    fn scan(result: FromResult) -> impl RowIterator {
        TableScanIterator::new(result.schema.clone(), result.into_rows())
    }

    /// Joins `left` with `right` on column 0 of both sides and collects every row,
    /// panicking on any construction or iteration error.
    fn join_on_first_columns<L: RowIterator>(left: L, right: FromResult) -> Vec<Row> {
        HashJoinIterator::new(left, right, 0, 0)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    #[test]
    fn test_hash_join_iterator_simple() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), text("Alice")],
                vec![SqlValue::Integer(2), text("Bob")],
                vec![SqlValue::Integer(3), text("Charlie")],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        let joined = join_on_first_columns(scan(users), orders);

        assert_eq!(joined.len(), 3);

        // Two user columns followed by two order columns.
        for row in joined.iter() {
            assert_eq!(row.values.len(), 4);
        }

        let rows_for = |id: SqlValue| joined.iter().filter(|row| row.values[0] == id).count();

        // Alice has two orders, Bob one, Charlie none.
        assert_eq!(rows_for(SqlValue::Integer(1)), 2);
        assert_eq!(rows_for(SqlValue::Integer(2)), 1);
        assert_eq!(rows_for(SqlValue::Integer(3)), 0);
    }

    #[test]
    fn test_hash_join_iterator_null_values() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), text("Alice")],
                vec![SqlValue::Null, text("Unknown")],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let joined = join_on_first_columns(scan(users), orders);

        // NULL keys on either side must not match each other.
        assert_eq!(joined.len(), 1);

        let only = &joined[0];
        let expected =
            [SqlValue::Integer(1), text("Alice"), SqlValue::Integer(1), SqlValue::Integer(100)];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(&only.values[idx], want);
        }
    }

    #[test]
    fn test_hash_join_iterator_no_matches() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        let joined = join_on_first_columns(scan(users), orders);

        assert!(joined.is_empty());
    }

    #[test]
    fn test_hash_join_iterator_empty_tables() {
        let users = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);
        let orders = create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        let joined = join_on_first_columns(scan(users), orders);

        assert!(joined.is_empty());
    }

    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), text("admin")],
                vec![SqlValue::Integer(1), text("user")],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let joined = join_on_first_columns(scan(users), orders);

        // 2 left rows x 2 right rows sharing the same key.
        assert_eq!(joined.len(), 4);

        for row in joined.iter() {
            assert_eq!(row.values[0], SqlValue::Integer(1));
        }
    }

    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        /// Left input that records how many rows have been pulled from it.
        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            index: usize,
            consumed_count: std::sync::Arc<std::sync::Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                let row = self.rows.get(self.index)?.clone();
                self.index += 1;
                *self.consumed_count.lock().unwrap() += 1;
                Some(Ok(row))
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let consumed = std::sync::Arc::new(std::sync::Mutex::new(0usize));

        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1)],
                vec![SqlValue::Integer(2)],
                vec![SqlValue::Integer(3)],
                vec![SqlValue::Integer(4)],
                vec![SqlValue::Integer(5)],
            ],
        );

        let left = CountingIterator {
            schema: users.schema.clone(),
            rows: users.into_rows(),
            index: 0,
            consumed_count: std::sync::Arc::clone(&consumed),
        };

        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let first_two: Vec<_> = HashJoinIterator::new(left, orders, 0, 0)
            .unwrap()
            .take(2)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(first_two.len(), 2);

        // Only as many left rows as needed for two outputs should have been pulled.
        let consumed_count = *consumed.lock().unwrap();
        assert!(consumed_count <= 3, "Expected at most 3 rows consumed, got {}", consumed_count);
    }
}