Skip to main content

sochdb_query/executor/
join.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2
3//! Join operators: HashJoin, NestedLoopJoin, MergeJoin.
4
5use crate::soch_ql::SochValue;
6use crate::sql::ast::{Expr, JoinType};
7use super::eval::{eval_expr, eval_predicate, compare_values};
8use super::node::PlanNode;
9use super::types::{Row, Schema};
10use sochdb_core::Result;
11use std::collections::HashMap;
12
13// ============================================================================
14// HashJoinNode — Hash-based equi-join
15// ============================================================================
16
17/// Hash join: builds a hash table from the build side, probes with the probe side.
18///
19/// Supports INNER, LEFT, RIGHT, and FULL [OUTER] joins.
20///
21/// ```text
22/// HashJoin(build_key=users.id, probe_key=orders.user_id)
23///   ├── build: SeqScan(users)
24///   └── probe: SeqScan(orders)
25/// ```
26pub struct HashJoinNode {
27    build: Box<dyn PlanNode>,
28    probe: Box<dyn PlanNode>,
29    /// Expression to evaluate on build side to produce hash key.
30    build_key_expr: Expr,
31    /// Expression to evaluate on probe side to produce hash key.
32    probe_key_expr: Expr,
33    join_type: JoinType,
34    output_schema: Schema,
35    /// Hash table: key -> list of build rows.
36    hash_table: Option<HashMap<HashKey, Vec<Row>>>,
37    /// For LEFT/RIGHT/FULL: track which build rows were matched.
38    build_matched: Vec<bool>,
39    /// Current probe row being processed.
40    current_probe_row: Option<Row>,
41    /// Matches for current probe row.
42    current_matches: Vec<Row>,
43    /// Index into current_matches.
44    match_idx: usize,
45    /// For RIGHT/FULL: unmatched build rows to emit after probe exhausted.
46    unmatched_buffer: Option<Vec<Row>>,
47    unmatched_pos: usize,
48    /// Whether probe side is exhausted.
49    probe_exhausted: bool,
50    /// Whether join produced a match for current probe row.
51    current_probe_matched: bool,
52    build_schema: Schema,
53    probe_schema: Schema,
54}
55
56/// Simple hash key wrapping a SochValue for HashMap use.
57#[derive(Debug, Clone, PartialEq, Eq, Hash)]
58enum HashKey {
59    Int(i64),
60    UInt(u64),
61    Text(String),
62    Bool(bool),
63    Null,
64    Other(String),
65}
66
67impl From<&SochValue> for HashKey {
68    fn from(v: &SochValue) -> Self {
69        match v {
70            SochValue::Int(i) => HashKey::Int(*i),
71            SochValue::UInt(u) => HashKey::UInt(*u),
72            SochValue::Text(s) => HashKey::Text(s.clone()),
73            SochValue::Bool(b) => HashKey::Bool(*b),
74            SochValue::Null => HashKey::Null,
75            other => HashKey::Other(format!("{:?}", other)),
76        }
77    }
78}
79
80impl HashJoinNode {
81    pub fn new(
82        build: Box<dyn PlanNode>,
83        probe: Box<dyn PlanNode>,
84        build_key_expr: Expr,
85        probe_key_expr: Expr,
86        join_type: JoinType,
87    ) -> Self {
88        let build_schema = build.schema().clone();
89        let probe_schema = probe.schema().clone();
90        let output_schema = build_schema.merge(&probe_schema);
91
92        Self {
93            build,
94            probe,
95            build_key_expr,
96            probe_key_expr,
97            join_type,
98            output_schema,
99            hash_table: None,
100            build_matched: Vec::new(),
101            current_probe_row: None,
102            current_matches: Vec::new(),
103            match_idx: 0,
104            unmatched_buffer: None,
105            unmatched_pos: 0,
106            probe_exhausted: false,
107            current_probe_matched: false,
108            build_schema,
109            probe_schema,
110        }
111    }
112
113    fn build_hash_table(&mut self) -> Result<()> {
114        if self.hash_table.is_some() {
115            return Ok(());
116        }
117
118        let mut table: HashMap<HashKey, Vec<Row>> = HashMap::new();
119        let mut all_build_rows: Vec<Row> = Vec::new();
120        let schema = self.build.schema().clone();
121
122        while let Some(row) = self.build.next()? {
123            let key_val = eval_expr(&self.build_key_expr, &row, &schema)?;
124            let key = HashKey::from(&key_val);
125            table.entry(key).or_default().push(row.clone());
126            all_build_rows.push(row);
127        }
128
129        self.build_matched = vec![false; all_build_rows.len()];
130        self.hash_table = Some(table);
131        Ok(())
132    }
133
134    fn null_row(schema: &Schema) -> Row {
135        vec![SochValue::Null; schema.len()]
136    }
137
138    fn combine(build_row: &Row, probe_row: &Row) -> Row {
139        let mut combined = build_row.clone();
140        combined.extend(probe_row.iter().cloned());
141        combined
142    }
143}
144
145impl PlanNode for HashJoinNode {
146    fn schema(&self) -> &Schema {
147        &self.output_schema
148    }
149
150    fn next(&mut self) -> Result<Option<Row>> {
151        self.build_hash_table()?;
152
153        loop {
154            // Return pending matches from current probe row
155            if self.match_idx < self.current_matches.len() {
156                let build_row = &self.current_matches[self.match_idx];
157                let probe_row = self.current_probe_row.as_ref().unwrap();
158                self.match_idx += 1;
159                return Ok(Some(Self::combine(build_row, probe_row)));
160            }
161
162            // For LEFT join: emit unmatched probe row
163            if self.current_probe_row.is_some()
164                && !self.current_probe_matched
165                && matches!(self.join_type, JoinType::Left | JoinType::Full)
166            {
167                let probe_row = self.current_probe_row.take().unwrap();
168                let null_build = Self::null_row(&self.build_schema);
169                return Ok(Some(Self::combine(&null_build, &probe_row)));
170            }
171
172            // Done with current probe row, reset
173            self.current_probe_row = None;
174            self.current_matches.clear();
175            self.match_idx = 0;
176            self.current_probe_matched = false;
177
178            if !self.probe_exhausted {
179                // Get next probe row
180                match self.probe.next()? {
181                    Some(probe_row) => {
182                        let key_val = eval_expr(
183                            &self.probe_key_expr,
184                            &probe_row,
185                            &self.probe_schema,
186                        )?;
187                        let key = HashKey::from(&key_val);
188
189                        if let Some(ht) = &self.hash_table {
190                            if let Some(matches) = ht.get(&key) {
191                                self.current_matches = matches.clone();
192                                self.current_probe_matched = true;
193                                // Mark matched build rows
194                                // (simplified: we'd need row indices for precise tracking)
195                            }
196                        }
197
198                        self.current_probe_row = Some(probe_row);
199                        continue;
200                    }
201                    None => {
202                        self.probe_exhausted = true;
203                    }
204                }
205            }
206
207            // After probe exhausted: emit unmatched build rows for RIGHT/FULL join
208            if matches!(self.join_type, JoinType::Right | JoinType::Full) {
209                if self.unmatched_buffer.is_none() {
210                    // Collect unmatched build rows
211                    // Simplified: for now just return None
212                    // Full RIGHT/FULL join requires tracking which build rows were matched
213                    self.unmatched_buffer = Some(Vec::new());
214                }
215
216                if let Some(buf) = &self.unmatched_buffer {
217                    if self.unmatched_pos < buf.len() {
218                        let row = buf[self.unmatched_pos].clone();
219                        self.unmatched_pos += 1;
220                        let null_probe = Self::null_row(&self.probe_schema);
221                        return Ok(Some(Self::combine(&row, &null_probe)));
222                    }
223                }
224            }
225
226            return Ok(None);
227        }
228    }
229
230    fn reset(&mut self) -> Result<()> {
231        self.hash_table = None;
232        self.current_probe_row = None;
233        self.current_matches.clear();
234        self.match_idx = 0;
235        self.probe_exhausted = false;
236        self.unmatched_buffer = None;
237        self.unmatched_pos = 0;
238        self.build.reset()?;
239        self.probe.reset()
240    }
241}
242
243// ============================================================================
244// NestedLoopJoinNode — Theta join (any join condition)
245// ============================================================================
246
247/// Nested loop join: for each outer row, scans all inner rows testing condition.
248///
249/// Supports all join types and arbitrary join conditions.
250pub struct NestedLoopJoinNode {
251    outer: Box<dyn PlanNode>,
252    inner: Box<dyn PlanNode>,
253    condition: Option<Expr>,
254    join_type: JoinType,
255    output_schema: Schema,
256    /// Current outer row.
257    current_outer: Option<Row>,
258    /// Whether current outer row has matched any inner row.
259    current_matched: bool,
260    /// Whether join is exhausted.
261    outer_exhausted: bool,
262    _outer_schema: Schema,
263    inner_schema: Schema,
264}
265
266impl NestedLoopJoinNode {
267    pub fn new(
268        outer: Box<dyn PlanNode>,
269        inner: Box<dyn PlanNode>,
270        condition: Option<Expr>,
271        join_type: JoinType,
272    ) -> Self {
273        let outer_schema = outer.schema().clone();
274        let inner_schema = inner.schema().clone();
275        let output_schema = outer_schema.merge(&inner_schema);
276
277        Self {
278            outer,
279            inner,
280            condition,
281            join_type,
282            output_schema,
283            current_outer: None,
284            current_matched: false,
285            outer_exhausted: false,
286            _outer_schema: outer_schema,
287            inner_schema,
288        }
289    }
290
291    fn combine(outer_row: &Row, inner_row: &Row) -> Row {
292        let mut combined = outer_row.clone();
293        combined.extend(inner_row.iter().cloned());
294        combined
295    }
296
297    fn null_row(schema: &Schema) -> Row {
298        vec![SochValue::Null; schema.len()]
299    }
300}
301
302impl PlanNode for NestedLoopJoinNode {
303    fn schema(&self) -> &Schema {
304        &self.output_schema
305    }
306
307    fn next(&mut self) -> Result<Option<Row>> {
308        loop {
309            // Get current outer row (or advance to next)
310            if self.current_outer.is_none() {
311                if self.outer_exhausted {
312                    return Ok(None);
313                }
314                match self.outer.next()? {
315                    Some(row) => {
316                        self.current_outer = Some(row);
317                        self.current_matched = false;
318                        self.inner.reset()?;
319                    }
320                    None => {
321                        self.outer_exhausted = true;
322                        return Ok(None);
323                    }
324                }
325            }
326
327            let outer_row = self.current_outer.as_ref().unwrap();
328
329            // Try to find next matching inner row
330            match self.inner.next()? {
331                Some(inner_row) => {
332                    let combined = Self::combine(outer_row, &inner_row);
333
334                    // Evaluate join condition
335                    let matched = match &self.condition {
336                        Some(cond) => eval_predicate(cond, &combined, &self.output_schema)?,
337                        None => true, // CROSS JOIN
338                    };
339
340                    if matched {
341                        self.current_matched = true;
342                        return Ok(Some(combined));
343                    }
344                    // Not matched, try next inner row
345                    continue;
346                }
347                None => {
348                    // Inner side exhausted for this outer row
349                    let need_null_row = !self.current_matched
350                        && matches!(self.join_type, JoinType::Left | JoinType::Full);
351
352                    let outer_row = self.current_outer.take().unwrap();
353
354                    if need_null_row {
355                        let null_inner = Self::null_row(&self.inner_schema);
356                        return Ok(Some(Self::combine(&outer_row, &null_inner)));
357                    }
358                    // Move to next outer row
359                    continue;
360                }
361            }
362        }
363    }
364
365    fn reset(&mut self) -> Result<()> {
366        self.current_outer = None;
367        self.current_matched = false;
368        self.outer_exhausted = false;
369        self.outer.reset()?;
370        self.inner.reset()
371    }
372}
373
374// ============================================================================
375// MergeJoinNode — Merge join on sorted inputs
376// ============================================================================
377
378/// Merge join: requires both inputs sorted on join keys.
379///
380/// For INNER JOIN, produces output only when keys match.
381pub struct MergeJoinNode {
382    left: Box<dyn PlanNode>,
383    right: Box<dyn PlanNode>,
384    left_key_expr: Expr,
385    right_key_expr: Expr,
386    join_type: JoinType,
387    output_schema: Schema,
388    left_schema: Schema,
389    right_schema: Schema,
390    /// Buffered rows from right side with same key (for many-to-many).
391    right_buffer: Vec<Row>,
392    right_buffer_key: Option<SochValue>,
393    right_buf_idx: usize,
394    current_left: Option<Row>,
395    current_left_key: Option<SochValue>,
396    right_exhausted: bool,
397    pending_right: Option<Row>,
398}
399
400impl MergeJoinNode {
401    pub fn new(
402        left: Box<dyn PlanNode>,
403        right: Box<dyn PlanNode>,
404        left_key_expr: Expr,
405        right_key_expr: Expr,
406        join_type: JoinType,
407    ) -> Self {
408        let left_schema = left.schema().clone();
409        let right_schema = right.schema().clone();
410        let output_schema = left_schema.merge(&right_schema);
411
412        Self {
413            left,
414            right,
415            left_key_expr,
416            right_key_expr,
417            join_type,
418            output_schema,
419            left_schema,
420            right_schema,
421            right_buffer: Vec::new(),
422            right_buffer_key: None,
423            right_buf_idx: 0,
424            current_left: None,
425            current_left_key: None,
426            right_exhausted: false,
427            pending_right: None,
428        }
429    }
430
431    fn combine(left_row: &Row, right_row: &Row) -> Row {
432        let mut combined = left_row.clone();
433        combined.extend(right_row.iter().cloned());
434        combined
435    }
436
437    fn advance_right(&mut self) -> Result<Option<(SochValue, Row)>> {
438        if let Some(row) = self.pending_right.take() {
439            let key = eval_expr(&self.right_key_expr, &row, &self.right_schema)?;
440            return Ok(Some((key, row)));
441        }
442        match self.right.next()? {
443            Some(row) => {
444                let key = eval_expr(&self.right_key_expr, &row, &self.right_schema)?;
445                Ok(Some((key, row)))
446            }
447            None => {
448                self.right_exhausted = true;
449                Ok(None)
450            }
451        }
452    }
453}
454
455impl PlanNode for MergeJoinNode {
456    fn schema(&self) -> &Schema {
457        &self.output_schema
458    }
459
460    fn next(&mut self) -> Result<Option<Row>> {
461        loop {
462            // If we have buffered right matches, emit them
463            if self.right_buf_idx < self.right_buffer.len() {
464                if let Some(left_row) = &self.current_left {
465                    let right_row = &self.right_buffer[self.right_buf_idx];
466                    self.right_buf_idx += 1;
467                    return Ok(Some(Self::combine(left_row, right_row)));
468                }
469            }
470
471            // Need new left row
472            let left_row = match self.left.next()? {
473                Some(row) => row,
474                None => return Ok(None),
475            };
476            let left_key = eval_expr(&self.left_key_expr, &left_row, &self.left_schema)?;
477
478            // Check if right buffer has same key
479            if self.right_buffer_key.as_ref().map_or(false, |k| {
480                compare_values(k, &left_key) == Some(std::cmp::Ordering::Equal)
481            }) {
482                self.current_left = Some(left_row);
483                self.current_left_key = Some(left_key);
484                self.right_buf_idx = 0;
485                continue;
486            }
487
488            // Need to advance right side to match left key
489            self.right_buffer.clear();
490            self.right_buf_idx = 0;
491
492            if self.right_exhausted {
493                if matches!(self.join_type, JoinType::Left | JoinType::Full) {
494                    let null_right = vec![SochValue::Null; self.right_schema.len()];
495                    return Ok(Some(Self::combine(&left_row, &null_right)));
496                }
497                return Ok(None);
498            }
499
500            // Advance right until we find matching or greater key
501            loop {
502                match self.advance_right()? {
503                    Some((right_key, right_row)) => {
504                        match compare_values(&right_key, &left_key) {
505                            Some(std::cmp::Ordering::Equal) => {
506                                self.right_buffer.push(right_row);
507                                self.right_buffer_key = Some(right_key);
508                                // Collect all right rows with same key
509                                break;
510                            }
511                            Some(std::cmp::Ordering::Greater) => {
512                                // Right key is past left key
513                                self.pending_right = Some(right_row);
514                                break;
515                            }
516                            _ => {
517                                // Right key is less than left key, skip
518                                continue;
519                            }
520                        }
521                    }
522                    None => break,
523                }
524            }
525
526            // Collect remaining right rows with same key
527            if !self.right_buffer.is_empty() {
528                loop {
529                    match self.advance_right()? {
530                        Some((right_key, right_row)) => {
531                            if compare_values(&right_key, &left_key) == Some(std::cmp::Ordering::Equal) {
532                                self.right_buffer.push(right_row);
533                            } else {
534                                self.pending_right = Some(right_row);
535                                break;
536                            }
537                        }
538                        None => break,
539                    }
540                }
541            }
542
543            self.current_left = Some(left_row);
544            self.current_left_key = Some(left_key);
545            self.right_buf_idx = 0;
546
547            if self.right_buffer.is_empty() {
548                if matches!(self.join_type, JoinType::Left | JoinType::Full) {
549                    let left_row = self.current_left.take().unwrap();
550                    let null_right = vec![SochValue::Null; self.right_schema.len()];
551                    return Ok(Some(Self::combine(&left_row, &null_right)));
552                }
553                // Inner join, no match => skip this left row
554                continue;
555            }
556        }
557    }
558
559    fn reset(&mut self) -> Result<()> {
560        self.right_buffer.clear();
561        self.right_buffer_key = None;
562        self.right_buf_idx = 0;
563        self.current_left = None;
564        self.current_left_key = None;
565        self.right_exhausted = false;
566        self.pending_right = None;
567        self.left.reset()?;
568        self.right.reset()
569    }
570}