1#![allow(clippy::manual_is_multiple_of)]
7
8use ahash::AHashMap;
9
10use super::{combine_rows, FromResult};
11use crate::{
12 errors::ExecutorError,
13 schema::CombinedSchema,
14 select::RowIterator,
15 timeout::{TimeoutContext, CHECK_INTERVAL},
16};
17
18pub struct HashJoinIterator<L: RowIterator> {
32 left: L,
34 right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
36 schema: CombinedSchema,
38 left_col_idx: usize,
40 #[allow(dead_code)]
42 right_col_idx: usize,
43 current_left_row: Option<vibesql_storage::Row>,
45 current_matches: Vec<vibesql_storage::Row>,
47 match_index: usize,
49 #[allow(dead_code)]
51 right_col_count: usize,
52 timeout_ctx: TimeoutContext,
54 iteration_count: usize,
56}
57
58impl<L: RowIterator> HashJoinIterator<L> {
59 #[allow(private_interfaces)]
71 pub fn new(
72 left: L,
73 right: FromResult,
74 left_col_idx: usize,
75 right_col_idx: usize,
76 ) -> Result<Self, ExecutorError> {
77 let right_table_name = right
79 .schema
80 .table_schemas
81 .keys()
82 .next()
83 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
84 .clone();
85
86 let right_schema = right
87 .schema
88 .table_schemas
89 .get(&right_table_name)
90 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
91 .1
92 .clone();
93
94 let right_col_count = right_schema.columns.len();
95
96 let combined_schema =
98 CombinedSchema::combine(left.schema().clone(), right_table_name, right_schema);
99
100 let timeout_ctx = TimeoutContext::new_default();
102
103 let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> = AHashMap::new();
106 let mut build_iterations = 0;
107
108 for row in right.into_rows() {
109 build_iterations += 1;
111 if build_iterations % CHECK_INTERVAL == 0 {
112 timeout_ctx.check()?;
113 }
114
115 let key = row.values[right_col_idx].clone();
116
117 if key != vibesql_types::SqlValue::Null {
119 hash_table.entry(key).or_default().push(row);
120 }
121 }
122
123 Ok(Self {
124 left,
125 right_hash_table: hash_table,
126 schema: combined_schema,
127 left_col_idx,
128 right_col_idx,
129 current_left_row: None,
130 current_matches: Vec::new(),
131 match_index: 0,
132 right_col_count,
133 timeout_ctx,
134 iteration_count: 0,
135 })
136 }
137
138 pub fn hash_table_size(&self) -> usize {
140 self.right_hash_table.values().map(|v| v.len()).sum()
141 }
142}
143
144impl<L: RowIterator> Iterator for HashJoinIterator<L> {
145 type Item = Result<vibesql_storage::Row, ExecutorError>;
146
147 fn next(&mut self) -> Option<Self::Item> {
148 loop {
149 self.iteration_count += 1;
151 if self.iteration_count % CHECK_INTERVAL == 0 {
152 if let Err(e) = self.timeout_ctx.check() {
153 return Some(Err(e));
154 }
155 }
156
157 if self.match_index < self.current_matches.len() {
159 let right_row = &self.current_matches[self.match_index];
160 self.match_index += 1;
161
162 if let Some(ref left_row) = self.current_left_row {
164 let combined_row = combine_rows(left_row, right_row);
165 return Some(Ok(combined_row));
166 }
167 }
168
169 match self.left.next() {
171 Some(Ok(left_row)) => {
172 let key = &left_row.values[self.left_col_idx];
173
174 if key == &vibesql_types::SqlValue::Null {
176 continue;
178 }
179
180 if let Some(matches) = self.right_hash_table.get(key) {
182 self.current_left_row = Some(left_row);
184 self.current_matches = matches.clone();
185 self.match_index = 0;
186 } else {
188 continue;
191 }
192 }
193 Some(Err(e)) => {
194 return Some(Err(e));
196 }
197 None => {
198 return None;
200 }
201 }
202 }
203 }
204}
205
206impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
207 fn schema(&self) -> &CombinedSchema {
208 &self.schema
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::select::TableScanIterator;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    /// Builds a single-table `FromResult` named `table_name` whose columns are all nullable.
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let table = TableSchema::new(
            table_name.to_string(),
            columns
                .iter()
                .map(|(name, dtype)| ColumnSchema::new(name.to_string(), dtype.clone(), true))
                .collect(),
        );
        let schema = CombinedSchema::from_table(table_name.to_string(), table);
        FromResult::from_rows(schema, rows.into_iter().map(Row::new).collect())
    }

    /// Scans `left`, hash-joins it with `right` on column 0 of both sides, and collects every
    /// joined row, panicking on any error.
    fn join_on_first_column(left: FromResult, right: FromResult) -> Vec<Row> {
        let left_iter = TableScanIterator::new(left.schema.clone(), left.into_rows());
        HashJoinIterator::new(left_iter, right, 0, 0)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    /// Counts joined rows whose first column (the left join key) equals `key`.
    fn rows_with_left_key(rows: &[Row], key: SqlValue) -> usize {
        rows.iter().filter(|row| row.values[0] == key).count()
    }

    #[test]
    fn test_hash_join_iterator_simple() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Integer(2), SqlValue::Varchar("Bob".to_string())],
                vec![SqlValue::Integer(3), SqlValue::Varchar("Charlie".to_string())],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        let joined = join_on_first_column(users, orders);

        // Alice has two orders, Bob one, Charlie none; each output row has 2 + 2 columns.
        assert_eq!(joined.len(), 3);
        assert!(joined.iter().all(|row| row.values.len() == 4));
        assert_eq!(rows_with_left_key(&joined, SqlValue::Integer(1)), 2);
        assert_eq!(rows_with_left_key(&joined, SqlValue::Integer(2)), 1);
        assert_eq!(rows_with_left_key(&joined, SqlValue::Integer(3)), 0);
    }

    #[test]
    fn test_hash_join_iterator_null_values() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Null, SqlValue::Varchar("Unknown".to_string())],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let joined = join_on_first_column(users, orders);

        // NULL keys never match on either side, so only the id = 1 pairing survives.
        assert_eq!(joined.len(), 1);
        let only = &joined[0];
        assert_eq!(only.values[0], SqlValue::Integer(1));
        assert_eq!(only.values[1], SqlValue::Varchar("Alice".to_string()));
        assert_eq!(only.values[2], SqlValue::Integer(1));
        assert_eq!(only.values[3], SqlValue::Integer(100));
    }

    #[test]
    fn test_hash_join_iterator_no_matches() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        assert!(join_on_first_column(users, orders).is_empty());
    }

    #[test]
    fn test_hash_join_iterator_empty_tables() {
        let users = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);
        let orders =
            create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        assert!(join_on_first_column(users, orders).is_empty());
    }

    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("admin".to_string())],
                vec![SqlValue::Integer(1), SqlValue::Varchar("user".to_string())],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let joined = join_on_first_column(users, orders);

        // Two left rows times two right rows sharing key 1 gives a 2 x 2 cross product.
        assert_eq!(joined.len(), 4);
        assert_eq!(rows_with_left_key(&joined, SqlValue::Integer(1)), joined.len());
    }

    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        /// Left input that records how many rows the join has pulled from it.
        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            position: usize,
            pulled: std::sync::Arc<std::sync::Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                let row = self.rows.get(self.position)?.clone();
                self.position += 1;
                *self.pulled.lock().unwrap() += 1;
                Some(Ok(row))
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let pulled = std::sync::Arc::new(std::sync::Mutex::new(0));

        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            (1..=5).map(|id| vec![SqlValue::Integer(id)]).collect(),
        );
        let left = CountingIterator {
            schema: users.schema.clone(),
            rows: users.into_rows(),
            position: 0,
            pulled: std::sync::Arc::clone(&pulled),
        };
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let first_two = HashJoinIterator::new(left, orders, 0, 0)
            .unwrap()
            .take(2)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(first_two.len(), 2);

        // Two matches only require left rows 1 and 2, so the join must not have drained
        // the whole left input.
        let pulled_count = *pulled.lock().unwrap();
        assert!(pulled_count <= 3, "Expected at most 3 rows consumed, got {}", pulled_count);
    }
}