1#![allow(clippy::manual_is_multiple_of)]
7
8use ahash::AHashMap;
9
10use super::{combine_rows, FromResult};
11use crate::{
12 errors::ExecutorError,
13 schema::CombinedSchema,
14 select::RowIterator,
15 timeout::{TimeoutContext, CHECK_INTERVAL},
16};
17
/// Iterator that joins a streamed left input against a right input that has
/// been fully loaded into an in-memory hash table keyed by the join column.
///
/// The hash table is built once in [`HashJoinIterator::new`]; afterwards left
/// rows are pulled one at a time and combined with every right row stored
/// under the same key. Left rows without a matching key produce no output.
pub struct HashJoinIterator<L: RowIterator> {
    /// Left-side row source, advanced lazily from `next`.
    left: L,
    /// Right-side rows grouped by the value of their join column.
    /// Rows whose join value is `SqlValue::Null` are never inserted.
    right_hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>>,
    /// Schema produced by combining the left schema with the right table's schema.
    schema: CombinedSchema,
    /// Index of the join-key column within each left row.
    left_col_idx: usize,
    /// Index of the join-key column within each right row.
    /// Only used while building the hash table; stored but not read afterwards
    /// in the code visible here, which is presumably why `dead_code` is allowed.
    #[allow(dead_code)]
    right_col_idx: usize,
    /// Left row whose matches are currently being emitted, if any.
    current_left_row: Option<vibesql_storage::Row>,
    /// Copy of the right rows that share the current left row's key.
    current_matches: Vec<vibesql_storage::Row>,
    /// Position of the next entry of `current_matches` to emit.
    match_index: usize,
    /// Number of columns in the right table's schema.
    /// Stored but not read in the code visible here (hence `dead_code`).
    #[allow(dead_code)]
    right_col_count: usize,
    /// Timeout guard created at construction and polled periodically.
    timeout_ctx: TimeoutContext,
    /// Number of loop iterations performed by `next`; used to decide when to
    /// poll `timeout_ctx` (every `CHECK_INTERVAL` iterations).
    iteration_count: usize,
}
57
58impl<L: RowIterator> HashJoinIterator<L> {
59 #[allow(private_interfaces)]
71 pub fn new(
72 left: L,
73 right: FromResult,
74 left_col_idx: usize,
75 right_col_idx: usize,
76 ) -> Result<Self, ExecutorError> {
77 let right_table_name = right
79 .schema
80 .table_schemas
81 .keys()
82 .next()
83 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
84 .clone();
85
86 let right_schema = right
87 .schema
88 .table_schemas
89 .get(&right_table_name)
90 .ok_or_else(|| ExecutorError::UnsupportedFeature("Complex JOIN".to_string()))?
91 .1
92 .clone();
93
94 let right_col_count = right_schema.columns.len();
95
96 let combined_schema =
98 CombinedSchema::combine(left.schema().clone(), right_table_name, right_schema);
99
100 let timeout_ctx = TimeoutContext::new_default();
102
103 let mut hash_table: AHashMap<vibesql_types::SqlValue, Vec<vibesql_storage::Row>> =
106 AHashMap::new();
107 let mut build_iterations = 0;
108
109 for row in right.into_rows() {
110 build_iterations += 1;
112 if build_iterations % CHECK_INTERVAL == 0 {
113 timeout_ctx.check()?;
114 }
115
116 let key = row.values[right_col_idx].clone();
117
118 if key != vibesql_types::SqlValue::Null {
120 hash_table.entry(key).or_default().push(row);
121 }
122 }
123
124 Ok(Self {
125 left,
126 right_hash_table: hash_table,
127 schema: combined_schema,
128 left_col_idx,
129 right_col_idx,
130 current_left_row: None,
131 current_matches: Vec::new(),
132 match_index: 0,
133 right_col_count,
134 timeout_ctx,
135 iteration_count: 0,
136 })
137 }
138
139 pub fn hash_table_size(&self) -> usize {
141 self.right_hash_table.values().map(|v| v.len()).sum()
142 }
143}
144
145impl<L: RowIterator> Iterator for HashJoinIterator<L> {
146 type Item = Result<vibesql_storage::Row, ExecutorError>;
147
148 fn next(&mut self) -> Option<Self::Item> {
149 loop {
150 self.iteration_count += 1;
152 if self.iteration_count % CHECK_INTERVAL == 0 {
153 if let Err(e) = self.timeout_ctx.check() {
154 return Some(Err(e));
155 }
156 }
157
158 if self.match_index < self.current_matches.len() {
160 let right_row = &self.current_matches[self.match_index];
161 self.match_index += 1;
162
163 if let Some(ref left_row) = self.current_left_row {
165 let combined_row = combine_rows(left_row, right_row);
166 return Some(Ok(combined_row));
167 }
168 }
169
170 match self.left.next() {
172 Some(Ok(left_row)) => {
173 let key = &left_row.values[self.left_col_idx];
174
175 if key == &vibesql_types::SqlValue::Null {
177 continue;
179 }
180
181 if let Some(matches) = self.right_hash_table.get(key) {
183 self.current_left_row = Some(left_row);
185 self.current_matches = matches.clone();
186 self.match_index = 0;
187 } else {
189 continue;
192 }
193 }
194 Some(Err(e)) => {
195 return Some(Err(e));
197 }
198 None => {
199 return None;
201 }
202 }
203 }
204 }
205}
206
impl<L: RowIterator> RowIterator for HashJoinIterator<L> {
    /// Returns the combined schema stored in this iterator, i.e. the schema
    /// built at construction time from the left schema and the right table's
    /// schema.
    fn schema(&self) -> &CombinedSchema {
        &self.schema
    }
}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::select::TableScanIterator;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    /// Builds an in-memory `FromResult` for a single table with the given
    /// columns and row values.
    fn create_test_from_result(
        table_name: &str,
        columns: Vec<(&str, DataType)>,
        rows: Vec<Vec<SqlValue>>,
    ) -> FromResult {
        let column_schemas = columns
            .into_iter()
            // Third argument is presumably the nullable flag (tests insert NULLs).
            .map(|(name, dtype)| ColumnSchema::new(name.to_string(), dtype, true))
            .collect();
        let table = TableSchema::new(table_name.to_string(), column_schemas);
        let schema = CombinedSchema::from_table(table_name.to_string(), table);

        FromResult::from_rows(schema, rows.into_iter().map(Row::new).collect())
    }

    /// Scans `left`, hash-joins it with `right` on column 0 of both sides and
    /// collects every joined row, panicking on any error.
    fn join_on_first_column(left: FromResult, right: FromResult) -> Vec<Row> {
        let left_iter = TableScanIterator::new(left.schema.clone(), left.into_rows());

        HashJoinIterator::new(left_iter, right, 0, 0)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    #[test]
    fn test_hash_join_iterator_simple() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Integer(2), SqlValue::Varchar("Bob".to_string())],
                vec![SqlValue::Integer(3), SqlValue::Varchar("Charlie".to_string())],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(2), SqlValue::Integer(200)],
                vec![SqlValue::Integer(1), SqlValue::Integer(150)],
            ],
        );

        let joined = join_on_first_column(users, orders);

        // Two orders for id 1, one for id 2, none for id 3.
        assert_eq!(joined.len(), 3);

        // Every joined row carries 2 left columns + 2 right columns.
        for row in &joined {
            assert_eq!(row.values.len(), 4);
        }

        let rows_for = |id: SqlValue| joined.iter().filter(|row| row.values[0] == id).count();
        assert_eq!(rows_for(SqlValue::Integer(1)), 2);
        assert_eq!(rows_for(SqlValue::Integer(2)), 1);
        assert_eq!(rows_for(SqlValue::Integer(3)), 0);
    }

    #[test]
    fn test_hash_join_iterator_null_values() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("name", DataType::Varchar { max_length: Some(50) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("Alice".to_string())],
                vec![SqlValue::Null, SqlValue::Varchar("Unknown".to_string())],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Null, SqlValue::Integer(200)],
            ],
        );

        let joined = join_on_first_column(users, orders);

        // The NULL keys on both sides must not pair up with each other.
        assert_eq!(joined.len(), 1);

        let only = &joined[0];
        assert_eq!(only.values[0], SqlValue::Integer(1));
        assert_eq!(only.values[1], SqlValue::Varchar("Alice".to_string()));
        assert_eq!(only.values[2], SqlValue::Integer(1));
        assert_eq!(only.values[3], SqlValue::Integer(100));
    }

    #[test]
    fn test_hash_join_iterator_no_matches() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(3)], vec![SqlValue::Integer(4)]],
        );

        let joined = join_on_first_column(users, orders);

        assert_eq!(joined.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_empty_tables() {
        let users = create_test_from_result("users", vec![("id", DataType::Integer)], vec![]);
        let orders = create_test_from_result("orders", vec![("user_id", DataType::Integer)], vec![]);

        let joined = join_on_first_column(users, orders);

        assert_eq!(joined.len(), 0);
    }

    #[test]
    fn test_hash_join_iterator_duplicate_keys() {
        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer), ("type", DataType::Varchar { max_length: Some(10) })],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Varchar("admin".to_string())],
                vec![SqlValue::Integer(1), SqlValue::Varchar("user".to_string())],
            ],
        );
        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer), ("amount", DataType::Integer)],
            vec![
                vec![SqlValue::Integer(1), SqlValue::Integer(100)],
                vec![SqlValue::Integer(1), SqlValue::Integer(200)],
            ],
        );

        let joined = join_on_first_column(users, orders);

        // 2 left rows x 2 right rows sharing the same key.
        assert_eq!(joined.len(), 4);

        for row in &joined {
            assert_eq!(row.values[0], SqlValue::Integer(1));
        }
    }

    #[test]
    fn test_hash_join_iterator_lazy_evaluation() {
        use std::sync::{Arc, Mutex};

        /// Left-side iterator that records how many rows have been pulled from it.
        struct CountingIterator {
            schema: CombinedSchema,
            rows: Vec<Row>,
            index: usize,
            consumed_count: Arc<Mutex<usize>>,
        }

        impl Iterator for CountingIterator {
            type Item = Result<Row, ExecutorError>;

            fn next(&mut self) -> Option<Self::Item> {
                let row = self.rows.get(self.index)?.clone();
                self.index += 1;
                *self.consumed_count.lock().unwrap() += 1;
                Some(Ok(row))
            }
        }

        impl RowIterator for CountingIterator {
            fn schema(&self) -> &CombinedSchema {
                &self.schema
            }
        }

        let consumed = Arc::new(Mutex::new(0usize));

        let users = create_test_from_result(
            "users",
            vec![("id", DataType::Integer)],
            (1..=5).map(|id| vec![SqlValue::Integer(id)]).collect(),
        );
        let counting_iter = CountingIterator {
            schema: users.schema.clone(),
            rows: users.into_rows(),
            index: 0,
            consumed_count: Arc::clone(&consumed),
        };

        let orders = create_test_from_result(
            "orders",
            vec![("user_id", DataType::Integer)],
            vec![vec![SqlValue::Integer(1)], vec![SqlValue::Integer(2)]],
        );

        let join_iter = HashJoinIterator::new(counting_iter, orders, 0, 0).unwrap();

        // Only ask for the first two joined rows.
        let results: Vec<_> = join_iter.take(2).collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(results.len(), 2);

        // The left side must not have been drained past what was needed.
        let consumed_count = *consumed.lock().unwrap();
        assert!(consumed_count <= 3, "Expected at most 3 rows consumed, got {}", consumed_count);
    }
}