vibesql_executor/memory/external_hash_join.rs

//! External Hash Join (Grace Hash Join) with partition-based disk spilling
//!
//! This module implements a memory-bounded hash join that can handle datasets
//! larger than available memory by using a partition-based approach:
//!
//! 1. **Partition Phase**: Both build and probe sides are partitioned by hash
//! 2. **Build Phase**: For each partition, build an in-memory hash table
//! 3. **Probe Phase**: Probe the hash table with matching partition rows
//!
//! When memory is exhausted, partitions are spilled to disk and processed
//! one at a time during the final join phase.
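//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compiled as a
//! doctest). It constructs the controller with the same `with_budget`
//! constructor the tests below use; result rows contain the build columns
//! followed by the probe columns.
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! let controller = Arc::new(MemoryController::with_budget(1024 * 1024));
//! let mut join = ExternalHashJoin::new(
//!     controller,
//!     ExternalHashJoinConfig::default(),
//!     vec![0], // build-side key column
//!     vec![0], // probe-side key column
//!     JoinType::Inner,
//! );
//!
//! join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("Alice".into())])?;
//! join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("NYC".into())])?;
//!
//! for row in join.finish()? {
//!     let row = row?; // [build columns..., probe columns...]
//! }
//! ```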

use std::{
    io::{self, Read},
    sync::Arc,
};

use ahash::AHashMap;
use vibesql_types::SqlValue;

use super::{
    controller::{MemoryController, MemoryReservation},
    row_serialization::{deserialize_value, serialize_value},
    spill::SpillFile,
};

/// Configuration for external hash join
#[derive(Debug, Clone)]
pub struct ExternalHashJoinConfig {
    /// Number of partitions (must be power of 2)
    pub num_partitions: usize,
    /// Maximum rows to keep in memory per partition before considering spill
    pub max_rows_per_partition: usize,
}

impl Default for ExternalHashJoinConfig {
    fn default() -> Self {
        Self { num_partitions: 64, max_rows_per_partition: 10_000 }
    }
}

/// Join type for the external hash join
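///
/// The probe side plays the role of the left input and the build side the
/// right input: `LeftOuter` keeps unmatched probe rows (build columns padded
/// with NULLs), while `RightOuter` keeps unmatched build rows (probe columns
/// padded with NULLs).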
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    LeftOuter,
    RightOuter,
}

/// Build-side partition
struct BuildPartition {
    /// In-memory rows: (key_values, full_row)
    rows: Vec<(Vec<SqlValue>, Vec<SqlValue>)>,
    /// Estimated memory usage
    memory_bytes: usize,
    /// Whether this partition has been spilled
    spilled: bool,
    /// Spill file for this partition
    spill_file: Option<SpillFile>,
    /// Number of rows written to spill file
    spilled_row_count: usize,
}

impl BuildPartition {
    fn new() -> Self {
        Self {
            rows: Vec::new(),
            memory_bytes: 0,
            spilled: false,
            spill_file: None,
            spilled_row_count: 0,
        }
    }
}

/// Probe-side partition
struct ProbePartition {
    /// In-memory rows: (key_values, full_row)
    rows: Vec<(Vec<SqlValue>, Vec<SqlValue>)>,
    /// Estimated memory usage
    memory_bytes: usize,
    /// Whether this partition has been spilled
    spilled: bool,
    /// Spill file for this partition
    spill_file: Option<SpillFile>,
    /// Number of rows written to spill file
    spilled_row_count: usize,
}

impl ProbePartition {
    fn new() -> Self {
        Self {
            rows: Vec::new(),
            memory_bytes: 0,
            spilled: false,
            spill_file: None,
            spilled_row_count: 0,
        }
    }
}

/// External Hash Join operator
pub struct ExternalHashJoin {
    /// Memory reservation
    reservation: MemoryReservation,
    /// Configuration
    #[allow(dead_code)]
    config: ExternalHashJoinConfig,
    /// Build-side partitions
    build_partitions: Vec<BuildPartition>,
    /// Probe-side partitions
    probe_partitions: Vec<ProbePartition>,
    /// Join key column indices for build side
    build_key_indices: Vec<usize>,
    /// Join key column indices for probe side
    probe_key_indices: Vec<usize>,
    /// Partition mask (num_partitions - 1)
    partition_mask: usize,
    /// Join type
    join_type: JoinType,
    /// Total rows on build side
    build_row_count: usize,
    /// Total rows on probe side
    probe_row_count: usize,
    /// Width of build rows (for outer joins)
    build_row_width: Option<usize>,
    /// Width of probe rows (for outer joins)
    probe_row_width: Option<usize>,
}

impl ExternalHashJoin {
    /// Create a new external hash join
    pub fn new(
        controller: Arc<MemoryController>,
        config: ExternalHashJoinConfig,
        build_key_indices: Vec<usize>,
        probe_key_indices: Vec<usize>,
        join_type: JoinType,
    ) -> Self {
        assert!(config.num_partitions.is_power_of_two(), "num_partitions must be power of 2");
        assert_eq!(
            build_key_indices.len(),
            probe_key_indices.len(),
            "build and probe must have same number of key columns"
        );

        let partition_mask = config.num_partitions - 1;
        let mut build_partitions = Vec::with_capacity(config.num_partitions);
        let mut probe_partitions = Vec::with_capacity(config.num_partitions);

        for _ in 0..config.num_partitions {
            build_partitions.push(BuildPartition::new());
            probe_partitions.push(ProbePartition::new());
        }

        Self {
            reservation: controller.create_reservation(),
            config,
            build_partitions,
            probe_partitions,
            build_key_indices,
            probe_key_indices,
            partition_mask,
            join_type,
            build_row_count: 0,
            probe_row_count: 0,
            build_row_width: None,
            probe_row_width: None,
        }
    }

    /// Add a row to the build side
    pub fn add_build_row(&mut self, row: &[SqlValue]) -> io::Result<()> {
        // Capture row width on first row (for outer joins)
        if self.build_row_width.is_none() {
            self.build_row_width = Some(row.len());
        }

        let key_values: Vec<SqlValue> = self
            .build_key_indices
            .iter()
            .map(|&idx| row.get(idx).cloned().unwrap_or(SqlValue::Null))
            .collect();

        // Skip rows with NULL keys (they never match in equi-joins)
        if key_values.iter().any(|v| v == &SqlValue::Null) {
            return Ok(());
        }

        let partition_idx = self.compute_partition(&key_values);
        let row_size = estimate_row_size(&key_values) + estimate_row_size(row);

        // Check if we need to spill
        if self.reservation.should_spill() {
            self.spill_largest_build_partition()?;
        }

        let partition = &mut self.build_partitions[partition_idx];

        if partition.spilled {
            // Write directly to spill file
            self.write_to_build_spill(partition_idx, &key_values, row)?;
        } else {
            // Try to grow memory reservation
            if !self.reservation.try_grow(row_size) {
                // Need to spill this partition
                self.spill_build_partition(partition_idx)?;
                self.write_to_build_spill(partition_idx, &key_values, row)?;
            } else {
                partition.rows.push((key_values, row.to_vec()));
                partition.memory_bytes += row_size;
            }
        }

        self.build_row_count += 1;
        Ok(())
    }

    /// Add a row to the probe side
    pub fn add_probe_row(&mut self, row: &[SqlValue]) -> io::Result<()> {
        // Capture row width on first row (for outer joins)
        if self.probe_row_width.is_none() {
            self.probe_row_width = Some(row.len());
        }

        let key_values: Vec<SqlValue> = self
            .probe_key_indices
            .iter()
            .map(|&idx| row.get(idx).cloned().unwrap_or(SqlValue::Null))
            .collect();

        // For inner joins, skip NULL keys. For outer joins, we need them.
        if self.join_type == JoinType::Inner && key_values.iter().any(|v| v == &SqlValue::Null) {
            return Ok(());
        }

        let partition_idx = self.compute_partition(&key_values);
        let row_size = estimate_row_size(&key_values) + estimate_row_size(row);

        // Check if we need to spill
        if self.reservation.should_spill() {
            self.spill_largest_probe_partition()?;
        }

        let partition = &mut self.probe_partitions[partition_idx];

        if partition.spilled {
            // Write directly to spill file
            self.write_to_probe_spill(partition_idx, &key_values, row)?;
        } else {
            // Try to grow memory reservation
            if !self.reservation.try_grow(row_size) {
                // Need to spill this partition
                self.spill_probe_partition(partition_idx)?;
                self.write_to_probe_spill(partition_idx, &key_values, row)?;
            } else {
                partition.rows.push((key_values, row.to_vec()));
                partition.memory_bytes += row_size;
            }
        }

        self.probe_row_count += 1;
        Ok(())
    }

    /// Compute partition index from key values
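    ///
    /// Both `add_build_row` and `add_probe_row` go through this same function,
    /// so rows with equal key values always land in the same partition index
    /// on both sides.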
    fn compute_partition(&self, key_values: &[SqlValue]) -> usize {
        use std::hash::Hasher;
        let mut hasher = ahash::AHasher::default();
        for v in key_values {
            hash_sql_value(v, &mut hasher);
        }
        (hasher.finish() as usize) & self.partition_mask
    }

    /// Spill the largest build partition to disk
    fn spill_largest_build_partition(&mut self) -> io::Result<()> {
        let largest_idx = self
            .build_partitions
            .iter()
            .enumerate()
            .filter(|(_, p)| !p.spilled && !p.rows.is_empty())
            .max_by_key(|(_, p)| p.memory_bytes)
            .map(|(i, _)| i);

        if let Some(idx) = largest_idx {
            self.spill_build_partition(idx)?;
        }
        Ok(())
    }

    /// Spill a specific build partition
    fn spill_build_partition(&mut self, idx: usize) -> io::Result<()> {
        let partition = &mut self.build_partitions[idx];
        if partition.spilled {
            return Ok(());
        }

        let temp_dir = self.reservation.temp_directory().clone();
        let mut spill_file = SpillFile::with_suffix(&temp_dir, &format!("build_part_{}", idx))?;

        // Write existing rows to spill file
        let rows = std::mem::take(&mut partition.rows);
        partition.spilled_row_count += rows.len();
        for (key, row) in rows {
            write_keyed_row(&mut spill_file, &key, &row)?;
        }
        spill_file.flush()?;

        self.reservation.shrink(partition.memory_bytes);
        partition.memory_bytes = 0;
        partition.spilled = true;
        partition.spill_file = Some(spill_file);

        Ok(())
    }

    /// Write a row to build spill file
    fn write_to_build_spill(
        &mut self,
        idx: usize,
        key_values: &[SqlValue],
        row: &[SqlValue],
    ) -> io::Result<()> {
        let partition = &mut self.build_partitions[idx];
        let spill_file = partition.spill_file.as_mut().expect("spill file should exist");

        write_keyed_row(spill_file, key_values, row)?;
        spill_file.flush()?;

        partition.spilled_row_count += 1;
        Ok(())
    }

    /// Spill the largest probe partition to disk
    fn spill_largest_probe_partition(&mut self) -> io::Result<()> {
        let largest_idx = self
            .probe_partitions
            .iter()
            .enumerate()
            .filter(|(_, p)| !p.spilled && !p.rows.is_empty())
            .max_by_key(|(_, p)| p.memory_bytes)
            .map(|(i, _)| i);

        if let Some(idx) = largest_idx {
            self.spill_probe_partition(idx)?;
        }
        Ok(())
    }

    /// Spill a specific probe partition
    fn spill_probe_partition(&mut self, idx: usize) -> io::Result<()> {
        let partition = &mut self.probe_partitions[idx];
        if partition.spilled {
            return Ok(());
        }

        let temp_dir = self.reservation.temp_directory().clone();
        let mut spill_file = SpillFile::with_suffix(&temp_dir, &format!("probe_part_{}", idx))?;

        // Write existing rows to spill file
        let rows = std::mem::take(&mut partition.rows);
        partition.spilled_row_count += rows.len();
        for (key, row) in rows {
            write_keyed_row(&mut spill_file, &key, &row)?;
        }
        spill_file.flush()?;

        self.reservation.shrink(partition.memory_bytes);
        partition.memory_bytes = 0;
        partition.spilled = true;
        partition.spill_file = Some(spill_file);

        Ok(())
    }

    /// Write a row to probe spill file
    fn write_to_probe_spill(
        &mut self,
        idx: usize,
        key_values: &[SqlValue],
        row: &[SqlValue],
    ) -> io::Result<()> {
        let partition = &mut self.probe_partitions[idx];
        let spill_file = partition.spill_file.as_mut().expect("spill file should exist");

        write_keyed_row(spill_file, key_values, row)?;
        spill_file.flush()?;

        partition.spilled_row_count += 1;
        Ok(())
    }

    /// Get number of spilled build partitions
    pub fn num_spilled_build_partitions(&self) -> usize {
        self.build_partitions.iter().filter(|p| p.spilled).count()
    }

    /// Get number of spilled probe partitions
    pub fn num_spilled_probe_partitions(&self) -> usize {
        self.probe_partitions.iter().filter(|p| p.spilled).count()
    }

    /// Execute the join and return results
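    ///
    /// Each partition is processed in turn (spilled partitions are read back
    /// from disk), and the joined rows of all partitions are materialized in
    /// memory before the iterator is returned.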
    pub fn finish(mut self) -> io::Result<HashJoinResultIterator> {
        let mut results = Vec::new();

        // Process each partition
        for partition_idx in 0..self.build_partitions.len() {
            let partition_results = self.process_partition(partition_idx)?;
            results.extend(partition_results);
        }

        Ok(HashJoinResultIterator { results: results.into_iter(), _reservation: self.reservation })
    }

    /// Process a single partition
    fn process_partition(&mut self, partition_idx: usize) -> io::Result<Vec<Vec<SqlValue>>> {
        // Load build side into hash table
        let build_rows = self.load_build_partition(partition_idx)?;

        // Build hash table
        let mut hash_table: AHashMap<Vec<SqlValue>, Vec<Vec<SqlValue>>> = AHashMap::new();
        for (key, row) in build_rows {
            hash_table.entry(key).or_default().push(row);
        }

        // Load probe side
        let probe_rows = self.load_probe_partition(partition_idx)?;

        // Perform join
        let mut results = Vec::new();

        match self.join_type {
            JoinType::Inner => {
                for (key, probe_row) in probe_rows {
                    if let Some(build_rows) = hash_table.get(&key) {
                        for build_row in build_rows {
                            let mut result = build_row.clone();
                            result.extend(probe_row.clone());
                            results.push(result);
                        }
                    }
                }
            }
            JoinType::LeftOuter => {
                // Use stored width or fall back to hash table if available
                let build_row_width = self.build_row_width.unwrap_or_else(|| {
                    hash_table.values().next().and_then(|v| v.first()).map(|r| r.len()).unwrap_or(0)
                });

                for (key, probe_row) in probe_rows {
                    if let Some(build_rows) = hash_table.get(&key) {
                        for build_row in build_rows {
                            let mut result = build_row.clone();
                            result.extend(probe_row.clone());
                            results.push(result);
                        }
                    } else {
                        // No match - emit probe row with NULLs for build side
                        let mut result = vec![SqlValue::Null; build_row_width];
                        result.extend(probe_row);
                        results.push(result);
                    }
                }
            }
            JoinType::RightOuter => {
                // Track which build rows were matched
                let mut matched: std::collections::HashSet<usize> =
                    std::collections::HashSet::new();
                // Use stored width or fall back to probe rows if available
                let probe_row_width = self
                    .probe_row_width
                    .unwrap_or_else(|| probe_rows.first().map(|(_, r)| r.len()).unwrap_or(0));

                // First pass: find matches
                let build_rows_vec: Vec<_> = hash_table
                    .iter()
                    .flat_map(|(k, rows)| rows.iter().map(move |r| (k.clone(), r.clone())))
                    .collect();

                for (key, probe_row) in &probe_rows {
                    if let Some(build_rows) = hash_table.get(key) {
                        for build_row in build_rows {
                            let mut result = build_row.clone();
                            result.extend(probe_row.clone());
                            results.push(result);
                        }
                        // Mark these build rows as matched
                        for (i, (k, _)) in build_rows_vec.iter().enumerate() {
                            if k == key {
                                matched.insert(i);
                            }
                        }
                    }
                }

                // Second pass: emit unmatched build rows with NULLs
                for (i, (_, build_row)) in build_rows_vec.iter().enumerate() {
                    if !matched.contains(&i) {
                        let mut result = build_row.clone();
                        result.extend(vec![SqlValue::Null; probe_row_width]);
                        results.push(result);
                    }
                }
            }
        }

        Ok(results)
    }

    /// Load build partition (from memory or disk)
    fn load_build_partition(
        &mut self,
        idx: usize,
    ) -> io::Result<Vec<(Vec<SqlValue>, Vec<SqlValue>)>> {
        let partition = &mut self.build_partitions[idx];

        if !partition.spilled {
            // Return in-memory rows
            return Ok(std::mem::take(&mut partition.rows));
        }

        // Read from spill file
        let spill_file = partition.spill_file.as_mut().expect("spill file should exist");
        spill_file.prepare_for_read()?;

        let mut rows = Vec::new();
        loop {
            match read_keyed_row(spill_file) {
                Ok(Some((key, row))) => rows.push((key, row)),
                Ok(None) => break,
                Err(e) => return Err(e),
            }
        }

        Ok(rows)
    }

    /// Load probe partition (from memory or disk)
    fn load_probe_partition(
        &mut self,
        idx: usize,
    ) -> io::Result<Vec<(Vec<SqlValue>, Vec<SqlValue>)>> {
        let partition = &mut self.probe_partitions[idx];

        if !partition.spilled {
            // Return in-memory rows
            return Ok(std::mem::take(&mut partition.rows));
        }

        // Read from spill file
        let spill_file = partition.spill_file.as_mut().expect("spill file should exist");
        spill_file.prepare_for_read()?;

        let mut rows = Vec::new();
        loop {
            match read_keyed_row(spill_file) {
                Ok(Some((key, row))) => rows.push((key, row)),
                Ok(None) => break,
                Err(e) => return Err(e),
            }
        }

        Ok(rows)
    }
}

/// Write a keyed row to a spill file
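///
/// Record layout: the number of key values as a little-endian `u16`, the
/// serialized key values, the number of row values as a little-endian `u16`,
/// then the serialized row values. `read_keyed_row` is the matching decoder.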
fn write_keyed_row(
    spill_file: &mut SpillFile,
    key: &[SqlValue],
    row: &[SqlValue],
) -> io::Result<()> {
    // Write key length
    let key_len = key.len() as u16;
    spill_file.write_all(&key_len.to_le_bytes())?;

    // Write key values
    let mut buf = Vec::new();
    for v in key {
        serialize_value(v, &mut buf)?;
    }
    spill_file.write_all(&buf)?;

    // Write row length
    let row_len = row.len() as u16;
    spill_file.write_all(&row_len.to_le_bytes())?;

    // Write row values
    buf.clear();
    for v in row {
        serialize_value(v, &mut buf)?;
    }
    spill_file.write_all(&buf)?;

    Ok(())
}

/// Read a keyed row from a spill file
fn read_keyed_row(
    spill_file: &mut SpillFile,
) -> io::Result<Option<(Vec<SqlValue>, Vec<SqlValue>)>> {
    // Read key length
    let mut len_buf = [0u8; 2];
    match spill_file.read_exact(&mut len_buf) {
        Ok(()) => {}
        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
        Err(e) => return Err(e),
    }
    let key_len = u16::from_le_bytes(len_buf) as usize;

    // Read key values
    let mut key = Vec::with_capacity(key_len);
    for _ in 0..key_len {
        key.push(deserialize_value_from_spill(spill_file)?);
    }

    // Read row length
    spill_file.read_exact(&mut len_buf)?;
    let row_len = u16::from_le_bytes(len_buf) as usize;

    // Read row values
    let mut row = Vec::with_capacity(row_len);
    for _ in 0..row_len {
        row.push(deserialize_value_from_spill(spill_file)?);
    }

    Ok(Some((key, row)))
}

/// Read a single value from spill file (wrapper around deserialize_value)
fn deserialize_value_from_spill(spill_file: &mut SpillFile) -> io::Result<SqlValue> {
    // We need to wrap SpillFile to implement Read
    struct SpillFileReader<'a>(&'a mut SpillFile);

    impl Read for SpillFileReader<'_> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.0.read(buf)
        }
    }

    let mut reader = SpillFileReader(spill_file);
    deserialize_value(&mut reader)
}

/// Hash a SqlValue for partitioning
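///
/// Each variant hashes a distinct tag byte before its payload, so values of
/// different SQL types with identical payload bits hash differently.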
fn hash_sql_value<H: std::hash::Hasher>(value: &SqlValue, hasher: &mut H) {
    use std::hash::Hash;
    match value {
        SqlValue::Null => 0u8.hash(hasher),
        SqlValue::Integer(i) => {
            1u8.hash(hasher);
            i.hash(hasher);
        }
        SqlValue::Smallint(i) => {
            2u8.hash(hasher);
            i.hash(hasher);
        }
        SqlValue::Bigint(i) => {
            3u8.hash(hasher);
            i.hash(hasher);
        }
        SqlValue::Unsigned(u) => {
            4u8.hash(hasher);
            u.hash(hasher);
        }
        SqlValue::Numeric(d) => {
            5u8.hash(hasher);
            d.to_string().hash(hasher);
        }
        SqlValue::Float(f) => {
            6u8.hash(hasher);
            f.to_bits().hash(hasher);
        }
        SqlValue::Real(f) => {
            7u8.hash(hasher);
            f.to_bits().hash(hasher);
        }
        SqlValue::Double(f) => {
            8u8.hash(hasher);
            f.to_bits().hash(hasher);
        }
        SqlValue::Character(s) | SqlValue::Varchar(s) => {
            9u8.hash(hasher);
            s.hash(hasher);
        }
        SqlValue::Boolean(b) => {
            10u8.hash(hasher);
            b.hash(hasher);
        }
        SqlValue::Date(d) => {
            11u8.hash(hasher);
            d.hash(hasher);
        }
        SqlValue::Time(t) => {
            12u8.hash(hasher);
            t.hash(hasher);
        }
        SqlValue::Timestamp(ts) => {
            13u8.hash(hasher);
            ts.hash(hasher);
        }
        SqlValue::Interval(iv) => {
            14u8.hash(hasher);
            iv.value.hash(hasher);
        }
        SqlValue::Vector(v) => {
            15u8.hash(hasher);
            for f in v {
                f.to_bits().hash(hasher);
            }
        }
        SqlValue::Blob(b) => {
            16u8.hash(hasher);
            b.hash(hasher);
        }
    }
}

/// Estimate memory size of a row
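///
/// This is a rough estimate: it counts the slice header and inline `SqlValue`
/// storage plus the heap payload of string and vector values; other
/// heap-allocated payloads (e.g. blobs) are not counted.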
fn estimate_row_size(row: &[SqlValue]) -> usize {
    let base_size = std::mem::size_of::<Vec<SqlValue>>() + std::mem::size_of_val(row);
    let value_size: usize = row
        .iter()
        .map(|v| match v {
            SqlValue::Varchar(s) | SqlValue::Character(s) => s.len(),
            SqlValue::Vector(vec) => vec.len() * std::mem::size_of::<f32>(),
            _ => 0,
        })
        .sum();
    base_size + value_size
}

/// Iterator over hash join results
pub struct HashJoinResultIterator {
    results: std::vec::IntoIter<Vec<SqlValue>>,
    #[allow(dead_code)]
    _reservation: MemoryReservation,
}

impl Iterator for HashJoinResultIterator {
    type Item = io::Result<Vec<SqlValue>>;

    fn next(&mut self) -> Option<Self::Item> {
        self.results.next().map(Ok)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_test_controller() -> Arc<MemoryController> {
        Arc::new(MemoryController::with_budget(1024 * 1024)) // 1MB
    }

    #[test]
    fn test_inner_join_basic() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join = ExternalHashJoin::new(
            controller,
            config,
            vec![0], // Build key column
            vec![0], // Probe key column
            JoinType::Inner,
        );

        // Build side: (id, name)
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("Alice".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(2), SqlValue::Varchar("Bob".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(3), SqlValue::Varchar("Charlie".into())]).unwrap();

        // Probe side: (id, city)
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("NYC".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(2), SqlValue::Varchar("LA".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(4), SqlValue::Varchar("Chicago".into())]).unwrap(); // No match

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Should have 2 results (id=1 and id=2 match)
        assert_eq!(results.len(), 2);

        // Verify Alice, NYC match
        let alice_match = results.iter().find(|r| r[1] == SqlValue::Varchar("Alice".into()));
        assert!(alice_match.is_some());
        assert_eq!(alice_match.unwrap()[3], SqlValue::Varchar("NYC".into()));

        // Verify Bob, LA match
        let bob_match = results.iter().find(|r| r[1] == SqlValue::Varchar("Bob".into()));
        assert!(bob_match.is_some());
        assert_eq!(bob_match.unwrap()[3], SqlValue::Varchar("LA".into()));
    }

    #[test]
    fn test_left_outer_join() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join =
            ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::LeftOuter);

        // Build side
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("A".into())]).unwrap();

        // Probe side (left table in left outer)
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("X".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(2), SqlValue::Varchar("Y".into())]).unwrap(); // No match

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Should have 2 results
        assert_eq!(results.len(), 2);

        // One with match, one with NULLs
        let matched = results.iter().filter(|r| r[0] != SqlValue::Null).count();
        let unmatched = results.iter().filter(|r| r[0] == SqlValue::Null).count();
        assert_eq!(matched, 1);
        assert_eq!(unmatched, 1);
    }

    #[test]
    fn test_multi_key_join() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join = ExternalHashJoin::new(
            controller,
            config,
            vec![0, 1], // Two key columns
            vec![0, 1],
            JoinType::Inner,
        );

        // Build: (a, b, val)
        join.add_build_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(10),
            SqlValue::Varchar("X".into()),
        ])
        .unwrap();
        join.add_build_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(20),
            SqlValue::Varchar("Y".into()),
        ])
        .unwrap();

        // Probe: (a, b, other)
        join.add_probe_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(10),
            SqlValue::Varchar("P".into()),
        ])
        .unwrap();
        join.add_probe_row(&[
            SqlValue::Integer(1),
            SqlValue::Integer(30),
            SqlValue::Varchar("Q".into()),
        ])
        .unwrap(); // No match

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Only (1, 10) matches
        assert_eq!(results.len(), 1);
        assert_eq!(results[0][2], SqlValue::Varchar("X".into()));
        assert_eq!(results[0][5], SqlValue::Varchar("P".into()));
    }

    #[test]
    fn test_null_handling() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        // Build with NULL key
        join.add_build_row(&[SqlValue::Null, SqlValue::Varchar("A".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("B".into())]).unwrap();

        // Probe with NULL key
        join.add_probe_row(&[SqlValue::Null, SqlValue::Varchar("X".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("Y".into())]).unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // NULLs should not match - only id=1 should match
        assert_eq!(results.len(), 1);
        assert_eq!(results[0][1], SqlValue::Varchar("B".into()));
    }

    #[test]
    fn test_empty_inputs() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig::default();

        let join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();
        assert!(results.is_empty());
    }

    #[test]
    fn test_many_partitions() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 16, max_rows_per_partition: 10 };

        let mut join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        // Add many rows to test partitioning
        for i in 0..100 {
            join.add_build_row(&[SqlValue::Integer(i), SqlValue::Integer(i * 10)]).unwrap();
            join.add_probe_row(&[SqlValue::Integer(i), SqlValue::Integer(i * 100)]).unwrap();
        }

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // All 100 rows should match
        assert_eq!(results.len(), 100);
    }

    #[test]
    fn test_duplicate_keys() {
        let controller = make_test_controller();
        let config = ExternalHashJoinConfig { num_partitions: 4, max_rows_per_partition: 100 };

        let mut join = ExternalHashJoin::new(controller, config, vec![0], vec![0], JoinType::Inner);

        // Build with duplicate keys
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("A".into())]).unwrap();
        join.add_build_row(&[SqlValue::Integer(1), SqlValue::Varchar("B".into())]).unwrap();

        // Probe with duplicate keys
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("X".into())]).unwrap();
        join.add_probe_row(&[SqlValue::Integer(1), SqlValue::Varchar("Y".into())]).unwrap();

        let results: Vec<_> = join.finish().unwrap().map(|r| r.unwrap()).collect();

        // Should have 2x2 = 4 results (cartesian product of matching rows)
        assert_eq!(results.len(), 4);
    }
}