Skip to main content

shape_runtime/
join_executor.rs

1//! JOIN execution engine
2//!
3//! Executes SQL-style JOINs between data sources:
4//! - INNER JOIN: Only matching rows
5//! - LEFT JOIN: All left rows, matching right rows (or nulls)
6//! - RIGHT JOIN: Matching left rows (or nulls), all right rows
7//! - FULL JOIN: All rows from both sides
8//! - CROSS JOIN: Cartesian product
9//! - TEMPORAL JOIN: Time-based matching within a tolerance window
10
11use crate::context::ExecutionContext;
12use shape_ast::ast::{JoinClause, JoinCondition, JoinType};
13use shape_ast::error::Result;
14use shape_value::ValueWord;
15use std::collections::HashMap;
16
/// Stateless entry point for joining two row sets.
///
/// Rows are `HashMap<String, ValueWord>` values keyed by column name. The
/// type carries no data; all functionality lives in its associated functions.
pub struct JoinExecutor;
19
20impl JoinExecutor {
21    /// Execute a join between left and right datasets
22    pub fn execute(
23        left: Vec<HashMap<String, ValueWord>>,
24        right: Vec<HashMap<String, ValueWord>>,
25        join: &JoinClause,
26        ctx: &mut ExecutionContext,
27    ) -> Result<Vec<HashMap<String, ValueWord>>> {
28        Self::execute_with_evaluator(left, right, join, None, ctx)
29    }
30
31    /// Execute a join with an optional expression evaluator for ON clause evaluation
32    pub fn execute_with_evaluator(
33        left: Vec<HashMap<String, ValueWord>>,
34        right: Vec<HashMap<String, ValueWord>>,
35        join: &JoinClause,
36        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
37        ctx: &mut ExecutionContext,
38    ) -> Result<Vec<HashMap<String, ValueWord>>> {
39        match join.join_type {
40            JoinType::Inner => Self::inner_join(left, right, &join.condition, evaluator, ctx),
41            JoinType::Left => Self::left_join(left, right, &join.condition, evaluator, ctx),
42            JoinType::Right => Self::right_join(left, right, &join.condition, evaluator, ctx),
43            JoinType::Full => Self::full_join(left, right, &join.condition, evaluator, ctx),
44            JoinType::Cross => Self::cross_join(left, right),
45        }
46    }
47
48    /// Execute INNER JOIN: only rows that match the condition
49    fn inner_join(
50        left: Vec<HashMap<String, ValueWord>>,
51        right: Vec<HashMap<String, ValueWord>>,
52        condition: &JoinCondition,
53        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
54        ctx: &mut ExecutionContext,
55    ) -> Result<Vec<HashMap<String, ValueWord>>> {
56        let mut results = Vec::new();
57
58        for l_row in &left {
59            for r_row in &right {
60                if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
61                    let merged = Self::merge_rows(l_row, r_row, "right");
62                    results.push(merged);
63                }
64            }
65        }
66
67        Ok(results)
68    }
69
70    /// Execute LEFT JOIN: all left rows, with matching right rows or nulls
71    fn left_join(
72        left: Vec<HashMap<String, ValueWord>>,
73        right: Vec<HashMap<String, ValueWord>>,
74        condition: &JoinCondition,
75        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
76        ctx: &mut ExecutionContext,
77    ) -> Result<Vec<HashMap<String, ValueWord>>> {
78        let mut results = Vec::new();
79
80        for l_row in &left {
81            let mut matched = false;
82
83            for r_row in &right {
84                if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
85                    let merged = Self::merge_rows(l_row, r_row, "right");
86                    results.push(merged);
87                    matched = true;
88                }
89            }
90
91            // No match - include left row with nulls for right columns
92            if !matched {
93                let merged = Self::merge_with_nulls(l_row, &right, "right");
94                results.push(merged);
95            }
96        }
97
98        Ok(results)
99    }
100
101    /// Execute RIGHT JOIN: matching left rows or nulls, with all right rows
102    fn right_join(
103        left: Vec<HashMap<String, ValueWord>>,
104        right: Vec<HashMap<String, ValueWord>>,
105        condition: &JoinCondition,
106        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
107        ctx: &mut ExecutionContext,
108    ) -> Result<Vec<HashMap<String, ValueWord>>> {
109        let mut results = Vec::new();
110
111        for r_row in &right {
112            let mut matched = false;
113
114            for l_row in &left {
115                if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
116                    let merged = Self::merge_rows(l_row, r_row, "right");
117                    results.push(merged);
118                    matched = true;
119                }
120            }
121
122            // No match - include right row with nulls for left columns
123            if !matched {
124                let merged = Self::merge_with_nulls_left(&left, r_row, "right");
125                results.push(merged);
126            }
127        }
128
129        Ok(results)
130    }
131
132    /// Execute FULL JOIN: all rows from both sides
133    fn full_join(
134        left: Vec<HashMap<String, ValueWord>>,
135        right: Vec<HashMap<String, ValueWord>>,
136        condition: &JoinCondition,
137        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
138        ctx: &mut ExecutionContext,
139    ) -> Result<Vec<HashMap<String, ValueWord>>> {
140        let mut results = Vec::new();
141        let mut right_matched = vec![false; right.len()];
142
143        // Left outer join part
144        for l_row in &left {
145            let mut matched = false;
146
147            for (r_idx, r_row) in right.iter().enumerate() {
148                if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
149                    let merged = Self::merge_rows(l_row, r_row, "right");
150                    results.push(merged);
151                    matched = true;
152                    right_matched[r_idx] = true;
153                }
154            }
155
156            if !matched {
157                let merged = Self::merge_with_nulls(l_row, &right, "right");
158                results.push(merged);
159            }
160        }
161
162        // Add unmatched right rows
163        for (r_idx, r_row) in right.iter().enumerate() {
164            if !right_matched[r_idx] {
165                let merged = Self::merge_with_nulls_left(&left, r_row, "right");
166                results.push(merged);
167            }
168        }
169
170        Ok(results)
171    }
172
173    /// Execute CROSS JOIN: Cartesian product
174    fn cross_join(
175        left: Vec<HashMap<String, ValueWord>>,
176        right: Vec<HashMap<String, ValueWord>>,
177    ) -> Result<Vec<HashMap<String, ValueWord>>> {
178        let mut results = Vec::new();
179
180        for l_row in &left {
181            for r_row in &right {
182                let merged = Self::merge_rows(l_row, r_row, "right");
183                results.push(merged);
184            }
185        }
186
187        Ok(results)
188    }
189
190    /// Check if two rows match the join condition
191    fn matches_condition(
192        left: &HashMap<String, ValueWord>,
193        right: &HashMap<String, ValueWord>,
194        condition: &JoinCondition,
195        evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
196        ctx: &mut ExecutionContext,
197    ) -> Result<bool> {
198        match condition {
199            JoinCondition::On(expr) => {
200                // Set up context with both row values
201                ctx.push_scope();
202
203                // Add left row values
204                for (k, v) in left {
205                    let _ = ctx.set_variable_nb(k, v.clone());
206                }
207
208                // Add right row values with prefix
209                for (k, v) in right {
210                    let _ = ctx.set_variable_nb(&format!("right.{}", k), v.clone());
211                }
212
213                let result = if let Some(eval) = evaluator {
214                    // ExpressionEvaluator returns ValueWord, convert to ValueWord for inspection
215                    let vm_result = eval
216                        .eval_expr(expr, ctx)
217                        .unwrap_or(ValueWord::from_bool(false));
218                    vm_result
219                } else {
220                    ValueWord::from_bool(true) // Fallback: match all if no evaluator
221                };
222                ctx.pop_scope();
223
224                if let Some(b) = result.as_bool() {
225                    Ok(b)
226                } else if let Some(n) = result.as_f64() {
227                    Ok(n != 0.0 && !n.is_nan())
228                } else {
229                    Ok(false)
230                }
231            }
232
233            JoinCondition::Using(columns) => {
234                // Match on specified columns
235                for col in columns {
236                    let l_val = left.get(col);
237                    let r_val = right.get(col);
238
239                    match (l_val, r_val) {
240                        (Some(a), Some(b)) if !nb_values_equal(a, b) => return Ok(false),
241                        (None, None) => {} // Both null matches
242                        (None, Some(_)) | (Some(_), None) => return Ok(false),
243                        _ => {}
244                    }
245                }
246                Ok(true)
247            }
248
249            JoinCondition::Temporal {
250                left_time,
251                right_time,
252                within,
253            } => {
254                let l_ts = left.get(left_time).and_then(extract_timestamp_nb);
255                let r_ts = right.get(right_time).and_then(extract_timestamp_nb);
256
257                if let (Some(l), Some(r)) = (l_ts, r_ts) {
258                    let diff_ms = (l - r).abs();
259                    let threshold_ms = within.to_seconds() as f64 * 1000.0;
260                    Ok(diff_ms <= threshold_ms)
261                } else {
262                    Ok(false)
263                }
264            }
265
266            JoinCondition::Natural => {
267                // Match on all common column names
268                for (k, l_val) in left {
269                    if let Some(r_val) = right.get(k) {
270                        if !nb_values_equal(l_val, r_val) {
271                            return Ok(false);
272                        }
273                    }
274                }
275                Ok(true)
276            }
277        }
278    }
279
280    /// Merge two rows, prefixing right columns
281    fn merge_rows(
282        left: &HashMap<String, ValueWord>,
283        right: &HashMap<String, ValueWord>,
284        right_prefix: &str,
285    ) -> HashMap<String, ValueWord> {
286        let mut merged = left.clone();
287
288        for (k, v) in right {
289            merged.insert(format!("{}.{}", right_prefix, k), v.clone());
290        }
291
292        merged
293    }
294
295    /// Merge left row with null values for right columns
296    fn merge_with_nulls(
297        left: &HashMap<String, ValueWord>,
298        right_sample: &[HashMap<String, ValueWord>],
299        right_prefix: &str,
300    ) -> HashMap<String, ValueWord> {
301        let mut merged = left.clone();
302
303        // Get column names from first right row (if any)
304        if let Some(first_right) = right_sample.first() {
305            for k in first_right.keys() {
306                merged.insert(format!("{}.{}", right_prefix, k), ValueWord::none());
307            }
308        }
309
310        merged
311    }
312
313    /// Merge null values for left columns with right row
314    fn merge_with_nulls_left(
315        left_sample: &[HashMap<String, ValueWord>],
316        right: &HashMap<String, ValueWord>,
317        right_prefix: &str,
318    ) -> HashMap<String, ValueWord> {
319        let mut merged = HashMap::new();
320
321        // Get column names from first left row (if any)
322        if let Some(first_left) = left_sample.first() {
323            for k in first_left.keys() {
324                merged.insert(k.clone(), ValueWord::none());
325            }
326        }
327
328        // Add right columns
329        for (k, v) in right {
330            merged.insert(format!("{}.{}", right_prefix, k), v.clone());
331        }
332
333        merged
334    }
335}
336
337/// Check if two ValueWord values are equal
338fn nb_values_equal(a: &ValueWord, b: &ValueWord) -> bool {
339    use shape_value::NanTag;
340    match (a.tag(), b.tag()) {
341        (NanTag::F64, NanTag::F64)
342        | (NanTag::I48, NanTag::I48)
343        | (NanTag::F64, NanTag::I48)
344        | (NanTag::I48, NanTag::F64) => {
345            if let (Some(an), Some(bn)) = (a.as_f64(), b.as_f64()) {
346                if an.is_nan() && bn.is_nan() {
347                    true
348                } else {
349                    (an - bn).abs() < f64::EPSILON
350                }
351            } else {
352                false
353            }
354        }
355        (NanTag::Heap, NanTag::Heap) => {
356            if let (Some(sa), Some(sb)) = (a.as_str(), b.as_str()) {
357                sa == sb
358            } else {
359                false
360            }
361        }
362        (NanTag::Bool, NanTag::Bool) => a.as_bool() == b.as_bool(),
363        (NanTag::None, NanTag::None) => true,
364        _ => {
365            // For Time values, fall back
366            if let (Some(ta), Some(tb)) = (a.as_time(), b.as_time()) {
367                ta == tb
368            } else {
369                false
370            }
371        }
372    }
373}
374
375/// Extract timestamp as milliseconds from a ValueWord value
376fn extract_timestamp_nb(v: &ValueWord) -> Option<f64> {
377    if let Some(n) = v.as_f64() {
378        Some(n)
379    } else if let Some(t) = v.as_time() {
380        Some(t.timestamp_millis() as f64)
381    } else {
382        None
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ExecutionContext;
    use shape_ast::ast::JoinSource;

    /// Shorthand for a numeric cell.
    fn num(value: f64) -> ValueWord {
        ValueWord::from_f64(value)
    }

    /// Shorthand for a string cell.
    fn text(value: &str) -> ValueWord {
        ValueWord::from_string(std::sync::Arc::new(value.to_string()))
    }

    /// Build a `USING (id)` join clause of the given type against a dummy source.
    fn using_id(join_type: JoinType) -> JoinClause {
        JoinClause {
            join_type,
            right: JoinSource::Named("test".to_string()),
            condition: JoinCondition::Using(vec!["id".to_string()]),
        }
    }

    /// Turn lists of `(column, value)` pairs into row maps.
    fn make_rows(data: Vec<Vec<(&str, ValueWord)>>) -> Vec<HashMap<String, ValueWord>> {
        let mut rows = Vec::with_capacity(data.len());
        for cells in data {
            let mut row = HashMap::new();
            for (column, value) in cells {
                row.insert(column.to_string(), value);
            }
            rows.push(row);
        }
        rows
    }

    #[test]
    fn test_inner_join_using() {
        let mut ctx = ExecutionContext::new_empty();

        let left = make_rows(vec![
            vec![("id", num(1.0)), ("name", text("A"))],
            vec![("id", num(2.0)), ("name", text("B"))],
            vec![("id", num(3.0)), ("name", text("C"))],
        ]);
        let right = make_rows(vec![
            vec![("id", num(1.0)), ("value", num(100.0))],
            vec![("id", num(3.0)), ("value", num(300.0))],
        ]);

        let join = using_id(JoinType::Inner);
        let result = JoinExecutor::execute(left, right, &join, &mut ctx).unwrap();

        // Only ids 1 and 3 exist on both sides.
        assert_eq!(result.len(), 2);

        // The first output row is the id = 1 pairing.
        let first = &result[0];
        assert_eq!(first.get("id").and_then(|v| v.as_f64()), Some(1.0));
        assert_eq!(first.get("name").and_then(|v| v.as_str()), Some("A"));
        assert_eq!(
            first.get("right.value").and_then(|v| v.as_f64()),
            Some(100.0)
        );
    }

    #[test]
    fn test_left_join() {
        let mut ctx = ExecutionContext::new_empty();

        let left = make_rows(vec![vec![("id", num(1.0))], vec![("id", num(2.0))]]);
        let right = make_rows(vec![vec![("id", num(1.0)), ("val", num(10.0))]]);

        let join = using_id(JoinType::Left);
        let result = JoinExecutor::execute(left, right, &join, &mut ctx).unwrap();

        // Every left row is kept.
        assert_eq!(result.len(), 2);

        // id = 1 found its partner.
        assert_eq!(
            result[0].get("right.val").and_then(|v| v.as_f64()),
            Some(10.0)
        );

        // id = 2 has no partner, so the right column is present but null.
        let padded = result[1].get("right.val");
        assert!(padded.map_or(false, |v| v.is_none()));
    }

    #[test]
    fn test_cross_join() {
        let left = make_rows(vec![vec![("a", num(1.0))], vec![("a", num(2.0))]]);
        let right = make_rows(vec![vec![("b", num(10.0))], vec![("b", num(20.0))]]);

        let result = JoinExecutor::cross_join(left, right).unwrap();

        // Cartesian product: 2 * 2 rows.
        assert_eq!(result.len(), 4);
    }
}