1use std::{
14 io::{self, Read},
15 sync::Arc,
16};
17
18use ahash::AHashMap;
19use vibesql_types::SqlValue;
20
21use super::{
22 controller::{MemoryController, MemoryReservation},
23 row_serialization::{deserialize_value, serialize_value},
24 spill::SpillFile,
25};
26
/// Tuning knobs for [`ExternalHashJoin`].
#[derive(Debug, Clone)]
pub struct ExternalHashJoinConfig {
    /// How many hash partitions each input is split into.
    ///
    /// `ExternalHashJoin::new` asserts that this is a power of two.
    pub num_partitions: usize,
    /// Row limit per partition.
    ///
    /// NOTE(review): nothing in this file reads this value yet — confirm intended use.
    pub max_rows_per_partition: usize,
}
35
36impl Default for ExternalHashJoinConfig {
37 fn default() -> Self {
38 Self { num_partitions: 64, max_rows_per_partition: 10_000 }
39 }
40}
41
/// Join flavours supported by [`ExternalHashJoin`].
///
/// Output rows are laid out as the build-side columns followed by the probe-side columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Emit only build/probe pairs whose keys are equal.
    Inner,
    /// Emit matching pairs, plus probe rows without a match; their build columns are NULL.
    LeftOuter,
    /// Emit matching pairs, plus build rows without a match; their probe columns are NULL.
    RightOuter,
}
49
50struct BuildPartition {
52 rows: Vec<(Vec<SqlValue>, Vec<SqlValue>)>,
54 memory_bytes: usize,
56 spilled: bool,
58 spill_file: Option<SpillFile>,
60 spilled_row_count: usize,
62}
63
64impl BuildPartition {
65 fn new() -> Self {
66 Self {
67 rows: Vec::new(),
68 memory_bytes: 0,
69 spilled: false,
70 spill_file: None,
71 spilled_row_count: 0,
72 }
73 }
74}
75
76struct ProbePartition {
78 rows: Vec<(Vec<SqlValue>, Vec<SqlValue>)>,
80 memory_bytes: usize,
82 spilled: bool,
84 spill_file: Option<SpillFile>,
86 spilled_row_count: usize,
88}
89
90impl ProbePartition {
91 fn new() -> Self {
92 Self {
93 rows: Vec::new(),
94 memory_bytes: 0,
95 spilled: false,
96 spill_file: None,
97 spilled_row_count: 0,
98 }
99 }
100}
101
102pub struct ExternalHashJoin {
104 reservation: MemoryReservation,
106 #[allow(dead_code)]
108 config: ExternalHashJoinConfig,
109 build_partitions: Vec<BuildPartition>,
111 probe_partitions: Vec<ProbePartition>,
113 build_key_indices: Vec<usize>,
115 probe_key_indices: Vec<usize>,
117 partition_mask: usize,
119 join_type: JoinType,
121 build_row_count: usize,
123 probe_row_count: usize,
125 build_row_width: Option<usize>,
127 probe_row_width: Option<usize>,
129}
130
131impl ExternalHashJoin {
132 pub fn new(
134 controller: Arc<MemoryController>,
135 config: ExternalHashJoinConfig,
136 build_key_indices: Vec<usize>,
137 probe_key_indices: Vec<usize>,
138 join_type: JoinType,
139 ) -> Self {
140 assert!(config.num_partitions.is_power_of_two(), "num_partitions must be power of 2");
141 assert_eq!(
142 build_key_indices.len(),
143 probe_key_indices.len(),
144 "build and probe must have same number of key columns"
145 );
146
147 let partition_mask = config.num_partitions - 1;
148 let mut build_partitions = Vec::with_capacity(config.num_partitions);
149 let mut probe_partitions = Vec::with_capacity(config.num_partitions);
150
151 for _ in 0..config.num_partitions {
152 build_partitions.push(BuildPartition::new());
153 probe_partitions.push(ProbePartition::new());
154 }
155
156 Self {
157 reservation: controller.create_reservation(),
158 config,
159 build_partitions,
160 probe_partitions,
161 build_key_indices,
162 probe_key_indices,
163 partition_mask,
164 join_type,
165 build_row_count: 0,
166 probe_row_count: 0,
167 build_row_width: None,
168 probe_row_width: None,
169 }
170 }
171
172 pub fn add_build_row(&mut self, row: &[SqlValue]) -> io::Result<()> {
174 if self.build_row_width.is_none() {
176 self.build_row_width = Some(row.len());
177 }
178
179 let key_values: Vec<SqlValue> = self
180 .build_key_indices
181 .iter()
182 .map(|&idx| row.get(idx).cloned().unwrap_or(SqlValue::Null))
183 .collect();
184
185 if key_values.iter().any(|v| v == &SqlValue::Null) {
187 return Ok(());
188 }
189
190 let partition_idx = self.compute_partition(&key_values);
191 let row_size = estimate_row_size(&key_values) + estimate_row_size(row);
192
193 if self.reservation.should_spill() {
195 self.spill_largest_build_partition()?;
196 }
197
198 let partition = &mut self.build_partitions[partition_idx];
199
200 if partition.spilled {
201 self.write_to_build_spill(partition_idx, &key_values, row)?;
203 } else {
204 if !self.reservation.try_grow(row_size) {
206 self.spill_build_partition(partition_idx)?;
208 self.write_to_build_spill(partition_idx, &key_values, row)?;
209 } else {
210 partition.rows.push((key_values, row.to_vec()));
211 partition.memory_bytes += row_size;
212 }
213 }
214
215 self.build_row_count += 1;
216 Ok(())
217 }
218
219 pub fn add_probe_row(&mut self, row: &[SqlValue]) -> io::Result<()> {
221 if self.probe_row_width.is_none() {
223 self.probe_row_width = Some(row.len());
224 }
225
226 let key_values: Vec<SqlValue> = self
227 .probe_key_indices
228 .iter()
229 .map(|&idx| row.get(idx).cloned().unwrap_or(SqlValue::Null))
230 .collect();
231
232 if self.join_type == JoinType::Inner && key_values.iter().any(|v| v == &SqlValue::Null) {
234 return Ok(());
235 }
236
237 let partition_idx = self.compute_partition(&key_values);
238 let row_size = estimate_row_size(&key_values) + estimate_row_size(row);
239
240 if self.reservation.should_spill() {
242 self.spill_largest_probe_partition()?;
243 }
244
245 let partition = &mut self.probe_partitions[partition_idx];
246
247 if partition.spilled {
248 self.write_to_probe_spill(partition_idx, &key_values, row)?;
250 } else {
251 if !self.reservation.try_grow(row_size) {
253 self.spill_probe_partition(partition_idx)?;
255 self.write_to_probe_spill(partition_idx, &key_values, row)?;
256 } else {
257 partition.rows.push((key_values, row.to_vec()));
258 partition.memory_bytes += row_size;
259 }
260 }
261
262 self.probe_row_count += 1;
263 Ok(())
264 }
265
266 fn compute_partition(&self, key_values: &[SqlValue]) -> usize {
268 use std::hash::Hasher;
269 let mut hasher = ahash::AHasher::default();
270 for v in key_values {
271 hash_sql_value(v, &mut hasher);
272 }
273 (hasher.finish() as usize) & self.partition_mask
274 }
275
276 fn spill_largest_build_partition(&mut self) -> io::Result<()> {
278 let largest_idx = self
279 .build_partitions
280 .iter()
281 .enumerate()
282 .filter(|(_, p)| !p.spilled && !p.rows.is_empty())
283 .max_by_key(|(_, p)| p.memory_bytes)
284 .map(|(i, _)| i);
285
286 if let Some(idx) = largest_idx {
287 self.spill_build_partition(idx)?;
288 }
289 Ok(())
290 }
291
292 fn spill_build_partition(&mut self, idx: usize) -> io::Result<()> {
294 let partition = &mut self.build_partitions[idx];
295 if partition.spilled {
296 return Ok(());
297 }
298
299 let temp_dir = self.reservation.temp_directory().clone();
300 let mut spill_file = SpillFile::with_suffix(&temp_dir, &format!("build_part_{}", idx))?;
301
302 let rows = std::mem::take(&mut partition.rows);
304 for (key, row) in rows {
305 write_keyed_row(&mut spill_file, &key, &row)?;
306 }
307 spill_file.flush()?;
308
309 self.reservation.shrink(partition.memory_bytes);
310 partition.memory_bytes = 0;
311 partition.spilled = true;
312 partition.spill_file = Some(spill_file);
313
314 Ok(())
315 }
316
317 fn write_to_build_spill(
319 &mut self,
320 idx: usize,
321 key_values: &[SqlValue],
322 row: &[SqlValue],
323 ) -> io::Result<()> {
324 let partition = &mut self.build_partitions[idx];
325 let spill_file = partition.spill_file.as_mut().expect("spill file should exist");
326
327 write_keyed_row(spill_file, key_values, row)?;
328 spill_file.flush()?;
329
330 partition.spilled_row_count += 1;
331 Ok(())
332 }
333
334 fn spill_largest_probe_partition(&mut self) -> io::Result<()> {
336 let largest_idx = self
337 .probe_partitions
338 .iter()
339 .enumerate()
340 .filter(|(_, p)| !p.spilled && !p.rows.is_empty())
341 .max_by_key(|(_, p)| p.memory_bytes)
342 .map(|(i, _)| i);
343
344 if let Some(idx) = largest_idx {
345 self.spill_probe_partition(idx)?;
346 }
347 Ok(())
348 }
349
350 fn spill_probe_partition(&mut self, idx: usize) -> io::Result<()> {
352 let partition = &mut self.probe_partitions[idx];
353 if partition.spilled {
354 return Ok(());
355 }
356
357 let temp_dir = self.reservation.temp_directory().clone();
358 let mut spill_file = SpillFile::with_suffix(&temp_dir, &format!("probe_part_{}", idx))?;
359
360 let rows = std::mem::take(&mut partition.rows);
362 for (key, row) in rows {
363 write_keyed_row(&mut spill_file, &key, &row)?;
364 }
365 spill_file.flush()?;
366
367 self.reservation.shrink(partition.memory_bytes);
368 partition.memory_bytes = 0;
369 partition.spilled = true;
370 partition.spill_file = Some(spill_file);
371
372 Ok(())
373 }
374
375 fn write_to_probe_spill(
377 &mut self,
378 idx: usize,
379 key_values: &[SqlValue],
380 row: &[SqlValue],
381 ) -> io::Result<()> {
382 let partition = &mut self.probe_partitions[idx];
383 let spill_file = partition.spill_file.as_mut().expect("spill file should exist");
384
385 write_keyed_row(spill_file, key_values, row)?;
386 spill_file.flush()?;
387
388 partition.spilled_row_count += 1;
389 Ok(())
390 }
391
392 pub fn num_spilled_build_partitions(&self) -> usize {
394 self.build_partitions.iter().filter(|p| p.spilled).count()
395 }
396
397 pub fn num_spilled_probe_partitions(&self) -> usize {
399 self.probe_partitions.iter().filter(|p| p.spilled).count()
400 }
401
402 pub fn finish(mut self) -> io::Result<HashJoinResultIterator> {
404 let mut results = Vec::new();
405
406 for partition_idx in 0..self.build_partitions.len() {
408 let partition_results = self.process_partition(partition_idx)?;
409 results.extend(partition_results);
410 }
411
412 Ok(HashJoinResultIterator { results: results.into_iter(), _reservation: self.reservation })
413 }
414
415 fn process_partition(&mut self, partition_idx: usize) -> io::Result<Vec<Vec<SqlValue>>> {
417 let build_rows = self.load_build_partition(partition_idx)?;
419
420 let mut hash_table: AHashMap<Vec<SqlValue>, Vec<Vec<SqlValue>>> = AHashMap::new();
422 for (key, row) in build_rows {
423 hash_table.entry(key).or_default().push(row);
424 }
425
426 let probe_rows = self.load_probe_partition(partition_idx)?;
428
429 let mut results = Vec::new();
431
432 match self.join_type {
433 JoinType::Inner => {
434 for (key, probe_row) in probe_rows {
435 if let Some(build_rows) = hash_table.get(&key) {
436 for build_row in build_rows {
437 let mut result = build_row.clone();
438 result.extend(probe_row.clone());
439 results.push(result);
440 }
441 }
442 }
443 }
444 JoinType::LeftOuter => {
445 let build_row_width = self.build_row_width.unwrap_or_else(|| {
447 hash_table.values().next().and_then(|v| v.first()).map(|r| r.len()).unwrap_or(0)
448 });
449
450 for (key, probe_row) in probe_rows {
451 if let Some(build_rows) = hash_table.get(&key) {
452 for build_row in build_rows {
453 let mut result = build_row.clone();
454 result.extend(probe_row.clone());
455 results.push(result);
456 }
457 } else {
458 let mut result = vec![SqlValue::Null; build_row_width];
460 result.extend(probe_row);
461 results.push(result);
462 }
463 }
464 }
465 JoinType::RightOuter => {
466 let mut matched: std::collections::HashSet<usize> =
468 std::collections::HashSet::new();
469 let probe_row_width = self
471 .probe_row_width
472 .unwrap_or_else(|| probe_rows.first().map(|(_, r)| r.len()).unwrap_or(0));
473
474 let build_rows_vec: Vec<_> = hash_table
476 .iter()
477 .flat_map(|(k, rows)| rows.iter().map(move |r| (k.clone(), r.clone())))
478 .collect();
479
480 for (key, probe_row) in &probe_rows {
481 if let Some(build_rows) = hash_table.get(key) {
482 for build_row in build_rows {
483 let mut result = build_row.clone();
484 result.extend(probe_row.clone());
485 results.push(result);
486 }
487 for (i, (k, _)) in build_rows_vec.iter().enumerate() {
489 if k == key {
490 matched.insert(i);
491 }
492 }
493 }
494 }
495
496 for (i, (_, build_row)) in build_rows_vec.iter().enumerate() {
498 if !matched.contains(&i) {
499 let mut result = build_row.clone();
500 result.extend(vec![SqlValue::Null; probe_row_width]);
501 results.push(result);
502 }
503 }
504 }
505 }
506
507 Ok(results)
508 }
509
510 fn load_build_partition(
512 &mut self,
513 idx: usize,
514 ) -> io::Result<Vec<(Vec<SqlValue>, Vec<SqlValue>)>> {
515 let partition = &mut self.build_partitions[idx];
516
517 if !partition.spilled {
518 return Ok(std::mem::take(&mut partition.rows));
520 }
521
522 let spill_file = partition.spill_file.as_mut().expect("spill file should exist");
524 spill_file.prepare_for_read()?;
525
526 let mut rows = Vec::new();
527 loop {
528 match read_keyed_row(spill_file) {
529 Ok(Some((key, row))) => rows.push((key, row)),
530 Ok(None) => break,
531 Err(e) => return Err(e),
532 }
533 }
534
535 Ok(rows)
536 }
537
538 fn load_probe_partition(
540 &mut self,
541 idx: usize,
542 ) -> io::Result<Vec<(Vec<SqlValue>, Vec<SqlValue>)>> {
543 let partition = &mut self.probe_partitions[idx];
544
545 if !partition.spilled {
546 return Ok(std::mem::take(&mut partition.rows));
548 }
549
550 let spill_file = partition.spill_file.as_mut().expect("spill file should exist");
552 spill_file.prepare_for_read()?;
553
554 let mut rows = Vec::new();
555 loop {
556 match read_keyed_row(spill_file) {
557 Ok(Some((key, row))) => rows.push((key, row)),
558 Ok(None) => break,
559 Err(e) => return Err(e),
560 }
561 }
562
563 Ok(rows)
564 }
565}
566
567fn write_keyed_row(
569 spill_file: &mut SpillFile,
570 key: &[SqlValue],
571 row: &[SqlValue],
572) -> io::Result<()> {
573 let key_len = key.len() as u16;
575 spill_file.write_all(&key_len.to_le_bytes())?;
576
577 let mut buf = Vec::new();
579 for v in key {
580 serialize_value(v, &mut buf)?;
581 }
582 spill_file.write_all(&buf)?;
583
584 let row_len = row.len() as u16;
586 spill_file.write_all(&row_len.to_le_bytes())?;
587
588 buf.clear();
590 for v in row {
591 serialize_value(v, &mut buf)?;
592 }
593 spill_file.write_all(&buf)?;
594
595 Ok(())
596}
597
598fn read_keyed_row(
600 spill_file: &mut SpillFile,
601) -> io::Result<Option<(Vec<SqlValue>, Vec<SqlValue>)>> {
602 let mut len_buf = [0u8; 2];
604 match spill_file.read_exact(&mut len_buf) {
605 Ok(()) => {}
606 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
607 Err(e) => return Err(e),
608 }
609 let key_len = u16::from_le_bytes(len_buf) as usize;
610
611 let mut key = Vec::with_capacity(key_len);
613 for _ in 0..key_len {
614 key.push(deserialize_value_from_spill(spill_file)?);
615 }
616
617 spill_file.read_exact(&mut len_buf)?;
619 let row_len = u16::from_le_bytes(len_buf) as usize;
620
621 let mut row = Vec::with_capacity(row_len);
623 for _ in 0..row_len {
624 row.push(deserialize_value_from_spill(spill_file)?);
625 }
626
627 Ok(Some((key, row)))
628}
629
630fn deserialize_value_from_spill(spill_file: &mut SpillFile) -> io::Result<SqlValue> {
632 struct SpillFileReader<'a>(&'a mut SpillFile);
634
635 impl Read for SpillFileReader<'_> {
636 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
637 self.0.read(buf)
638 }
639 }
640
641 let mut reader = SpillFileReader(spill_file);
642 deserialize_value(&mut reader)
643}
644
645fn hash_sql_value<H: std::hash::Hasher>(value: &SqlValue, hasher: &mut H) {
647 use std::hash::Hash;
648 match value {
649 SqlValue::Null => 0u8.hash(hasher),
650 SqlValue::Integer(i) => {
651 1u8.hash(hasher);
652 i.hash(hasher);
653 }
654 SqlValue::Smallint(i) => {
655 2u8.hash(hasher);
656 i.hash(hasher);
657 }
658 SqlValue::Bigint(i) => {
659 3u8.hash(hasher);
660 i.hash(hasher);
661 }
662 SqlValue::Unsigned(u) => {
663 4u8.hash(hasher);
664 u.hash(hasher);
665 }
666 SqlValue::Numeric(d) => {
667 5u8.hash(hasher);
668 d.to_string().hash(hasher);
669 }
670 SqlValue::Float(f) => {
671 6u8.hash(hasher);
672 f.to_bits().hash(hasher);
673 }
674 SqlValue::Real(f) => {
675 7u8.hash(hasher);
676 f.to_bits().hash(hasher);
677 }
678 SqlValue::Double(f) => {
679 8u8.hash(hasher);
680 f.to_bits().hash(hasher);
681 }
682 SqlValue::Character(s) | SqlValue::Varchar(s) => {
683 9u8.hash(hasher);
684 s.hash(hasher);
685 }
686 SqlValue::Boolean(b) => {
687 10u8.hash(hasher);
688 b.hash(hasher);
689 }
690 SqlValue::Date(d) => {
691 11u8.hash(hasher);
692 d.hash(hasher);
693 }
694 SqlValue::Time(t) => {
695 12u8.hash(hasher);
696 t.hash(hasher);
697 }
698 SqlValue::Timestamp(ts) => {
699 13u8.hash(hasher);
700 ts.hash(hasher);
701 }
702 SqlValue::Interval(iv) => {
703 14u8.hash(hasher);
704 iv.value.hash(hasher);
705 }
706 SqlValue::Vector(v) => {
707 15u8.hash(hasher);
708 for f in v {
709 f.to_bits().hash(hasher);
710 }
711 }
712 SqlValue::Blob(b) => {
713 16u8.hash(hasher);
714 b.hash(hasher);
715 }
716 }
717}
718
719fn estimate_row_size(row: &[SqlValue]) -> usize {
721 let base_size = std::mem::size_of::<Vec<SqlValue>>() + std::mem::size_of_val(row);
722 let value_size: usize = row
723 .iter()
724 .map(|v| match v {
725 SqlValue::Varchar(s) | SqlValue::Character(s) => s.len(),
726 SqlValue::Vector(vec) => vec.len() * std::mem::size_of::<f32>(),
727 _ => 0,
728 })
729 .sum();
730 base_size + value_size
731}
732
733pub struct HashJoinResultIterator {
735 results: std::vec::IntoIter<Vec<SqlValue>>,
736 #[allow(dead_code)]
737 _reservation: MemoryReservation,
738}
739
740impl Iterator for HashJoinResultIterator {
741 type Item = io::Result<Vec<SqlValue>>;
742
743 fn next(&mut self) -> Option<Self::Item> {
744 self.results.next().map(Ok)
745 }
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    /// Controller with a 1 MiB budget, shared by all tests below.
    fn make_test_controller() -> Arc<MemoryController> {
        Arc::new(MemoryController::with_budget(1024 * 1024))
    }

    #[test]
    fn test_inner_join_basic() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        // Build side: (id, name)
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("Alice".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(2), SqlValue::Varchar("Bob".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(3), SqlValue::Varchar("Charlie".into())]).unwrap();

        // Probe side: (id, city); id 4 has no build partner.
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("NYC".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(2), SqlValue::Varchar("LA".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(4), SqlValue::Varchar("Chicago".into())]).unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        assert_eq!(results.len(), 2);

        let alice_match = results.iter().find(|r| r[1] == SqlValue::Varchar("Alice".into()));
        assert!(alice_match.is_some());
        assert_eq!(alice_match.unwrap()[3], SqlValue::Varchar("NYC".into()));

        let bob_match = results.iter().find(|r| r[1] == SqlValue::Varchar("Bob".into()));
        assert!(bob_match.is_some());
        assert_eq!(bob_match.unwrap()[3], SqlValue::Varchar("LA".into()));
    }

    #[test]
    fn test_left_outer_join() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join =
            ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::LeftOuter);

        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("A".into())]).unwrap();

        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("X".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(2), SqlValue::Varchar("Y".into())]).unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Both probe rows survive; the unmatched one has NULL build columns.
        assert_eq!(results.len(), 2);

        let matched = results.iter().filter(|r| r[0] != SqlValue::Null).count();
        let unmatched = results.iter().filter(|r| r[0] == SqlValue::Null).count();
        assert_eq!(matched, 1);
        assert_eq!(unmatched, 1);
    }

    #[test]
    fn test_right_outer_join() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join =
            ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::RightOuter);

        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("A".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(2), SqlValue::Varchar("B".into())]).unwrap();

        // Key 3 has no build partner and must not appear in the output.
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("X".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(3), SqlValue::Varchar("Z".into())]).unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Both build rows survive; the unmatched one has NULL probe columns.
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.len() == 4));

        let matched: Vec<_> = results.iter().filter(|r| r[2] != SqlValue::Null).collect();
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0][1], SqlValue::Varchar("A".into()));
        assert_eq!(matched[0][3], SqlValue::Varchar("X".into()));

        let unmatched: Vec<_> = results.iter().filter(|r| r[2] == SqlValue::Null).collect();
        assert_eq!(unmatched.len(), 1);
        assert_eq!(unmatched[0][1], SqlValue::Varchar("B".into()));
        assert_eq!(unmatched[0][3], SqlValue::Null);
    }

    #[test]
    fn test_multi_key_join() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join =
            ExternalHashJoin::new(controller, config, vec![0, 1], vec![0, 1], JoinType::Inner);

        join.add_build_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(10),
            SqlValue::Varchar("X".into()),
        ])
        .unwrap();
        join.add_build_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(20),
            SqlValue::Varchar("Y".into()),
        ])
        .unwrap();

        join.add_probe_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(10),
            SqlValue::Varchar("P".into()),
        ])
        .unwrap();
        join.add_probe_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(30),
            SqlValue::Varchar("Q".into()),
        ])
        .unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Only (1, 10) matches on both key columns.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0][2], SqlValue::Varchar("X".into()));
        assert_eq!(results[0][5], SqlValue::Varchar("P".into()));
    }

    #[test]
    fn test_null_handling() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        join.add_build_row(&[SqlValue::Null, SqlValue::Varchar("A".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("B".into())]).unwrap();

        join.add_probe_row(&[SqlValue::Null, SqlValue::Varchar("X".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("Y".into())]).unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // NULL keys never match each other.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0][1], SqlValue::Varchar("B".into()));
    }

    #[test]
    fn test_empty_inputs() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig::default();

        let join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();
        assert!(results.is_empty());
    }

    #[test]
    fn test_many_partitions() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 16, max_rows_per_partition: 10 };

        let mut join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        for i in 0..100 {
            join.add_build_row(&[SqlValue::Integer(i), SqlValue::Integer(i * 10)]).unwrap();
            join.add_probe_row(&[SqlValue::Integer(i), SqlValue::Integer(i * 100)]).unwrap();
        }

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Every key matches exactly once, whichever partition it lands in.
        assert_eq!(results.len(), 100);
    }

    #[test]
    fn test_duplicate_keys() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("A".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("B".into())]).unwrap();

        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("X".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("Y".into())]).unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // 2 build rows x 2 probe rows with the same key.
        assert_eq!(results.len(), 4);
    }
}