use crate::context::ExecutionContext;
use shape_ast::ast::{JoinClause, JoinCondition, JoinType};
use shape_ast::error::Result;
use shape_value::ValueWord;
use std::collections::{HashMap, HashSet};

/// Stateless executor for row-level joins over `HashMap<String, ValueWord>` rows.
///
/// All functionality is exposed through associated functions such as
/// [`JoinExecutor::execute`]; the type itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct JoinExecutor;

20impl JoinExecutor {
21 pub fn execute(
23 left: Vec<HashMap<String, ValueWord>>,
24 right: Vec<HashMap<String, ValueWord>>,
25 join: &JoinClause,
26 ctx: &mut ExecutionContext,
27 ) -> Result<Vec<HashMap<String, ValueWord>>> {
28 Self::execute_with_evaluator(left, right, join, None, ctx)
29 }
30
31 pub fn execute_with_evaluator(
33 left: Vec<HashMap<String, ValueWord>>,
34 right: Vec<HashMap<String, ValueWord>>,
35 join: &JoinClause,
36 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
37 ctx: &mut ExecutionContext,
38 ) -> Result<Vec<HashMap<String, ValueWord>>> {
39 match join.join_type {
40 JoinType::Inner => Self::inner_join(left, right, &join.condition, evaluator, ctx),
41 JoinType::Left => Self::left_join(left, right, &join.condition, evaluator, ctx),
42 JoinType::Right => Self::right_join(left, right, &join.condition, evaluator, ctx),
43 JoinType::Full => Self::full_join(left, right, &join.condition, evaluator, ctx),
44 JoinType::Cross => Self::cross_join(left, right),
45 }
46 }
47
48 fn inner_join(
50 left: Vec<HashMap<String, ValueWord>>,
51 right: Vec<HashMap<String, ValueWord>>,
52 condition: &JoinCondition,
53 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
54 ctx: &mut ExecutionContext,
55 ) -> Result<Vec<HashMap<String, ValueWord>>> {
56 let mut results = Vec::new();
57
58 for l_row in &left {
59 for r_row in &right {
60 if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
61 let merged = Self::merge_rows(l_row, r_row, "right");
62 results.push(merged);
63 }
64 }
65 }
66
67 Ok(results)
68 }
69
70 fn left_join(
72 left: Vec<HashMap<String, ValueWord>>,
73 right: Vec<HashMap<String, ValueWord>>,
74 condition: &JoinCondition,
75 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
76 ctx: &mut ExecutionContext,
77 ) -> Result<Vec<HashMap<String, ValueWord>>> {
78 let mut results = Vec::new();
79
80 for l_row in &left {
81 let mut matched = false;
82
83 for r_row in &right {
84 if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
85 let merged = Self::merge_rows(l_row, r_row, "right");
86 results.push(merged);
87 matched = true;
88 }
89 }
90
91 if !matched {
93 let merged = Self::merge_with_nulls(l_row, &right, "right");
94 results.push(merged);
95 }
96 }
97
98 Ok(results)
99 }
100
101 fn right_join(
103 left: Vec<HashMap<String, ValueWord>>,
104 right: Vec<HashMap<String, ValueWord>>,
105 condition: &JoinCondition,
106 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
107 ctx: &mut ExecutionContext,
108 ) -> Result<Vec<HashMap<String, ValueWord>>> {
109 let mut results = Vec::new();
110
111 for r_row in &right {
112 let mut matched = false;
113
114 for l_row in &left {
115 if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
116 let merged = Self::merge_rows(l_row, r_row, "right");
117 results.push(merged);
118 matched = true;
119 }
120 }
121
122 if !matched {
124 let merged = Self::merge_with_nulls_left(&left, r_row, "right");
125 results.push(merged);
126 }
127 }
128
129 Ok(results)
130 }
131
132 fn full_join(
134 left: Vec<HashMap<String, ValueWord>>,
135 right: Vec<HashMap<String, ValueWord>>,
136 condition: &JoinCondition,
137 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
138 ctx: &mut ExecutionContext,
139 ) -> Result<Vec<HashMap<String, ValueWord>>> {
140 let mut results = Vec::new();
141 let mut right_matched = vec![false; right.len()];
142
143 for l_row in &left {
145 let mut matched = false;
146
147 for (r_idx, r_row) in right.iter().enumerate() {
148 if Self::matches_condition(l_row, r_row, condition, evaluator, ctx)? {
149 let merged = Self::merge_rows(l_row, r_row, "right");
150 results.push(merged);
151 matched = true;
152 right_matched[r_idx] = true;
153 }
154 }
155
156 if !matched {
157 let merged = Self::merge_with_nulls(l_row, &right, "right");
158 results.push(merged);
159 }
160 }
161
162 for (r_idx, r_row) in right.iter().enumerate() {
164 if !right_matched[r_idx] {
165 let merged = Self::merge_with_nulls_left(&left, r_row, "right");
166 results.push(merged);
167 }
168 }
169
170 Ok(results)
171 }
172
173 fn cross_join(
175 left: Vec<HashMap<String, ValueWord>>,
176 right: Vec<HashMap<String, ValueWord>>,
177 ) -> Result<Vec<HashMap<String, ValueWord>>> {
178 let mut results = Vec::new();
179
180 for l_row in &left {
181 for r_row in &right {
182 let merged = Self::merge_rows(l_row, r_row, "right");
183 results.push(merged);
184 }
185 }
186
187 Ok(results)
188 }
189
190 fn matches_condition(
192 left: &HashMap<String, ValueWord>,
193 right: &HashMap<String, ValueWord>,
194 condition: &JoinCondition,
195 evaluator: Option<&dyn crate::engine::ExpressionEvaluator>,
196 ctx: &mut ExecutionContext,
197 ) -> Result<bool> {
198 match condition {
199 JoinCondition::On(expr) => {
200 ctx.push_scope();
202
203 for (k, v) in left {
205 let _ = ctx.set_variable_nb(k, v.clone());
206 }
207
208 for (k, v) in right {
210 let _ = ctx.set_variable_nb(&format!("right.{}", k), v.clone());
211 }
212
213 let result = if let Some(eval) = evaluator {
214 let vm_result = eval
216 .eval_expr(expr, ctx)
217 .unwrap_or(ValueWord::from_bool(false));
218 vm_result
219 } else {
220 ValueWord::from_bool(true) };
222 ctx.pop_scope();
223
224 if let Some(b) = result.as_bool() {
225 Ok(b)
226 } else if let Some(n) = result.as_f64() {
227 Ok(n != 0.0 && !n.is_nan())
228 } else {
229 Ok(false)
230 }
231 }
232
233 JoinCondition::Using(columns) => {
234 for col in columns {
236 let l_val = left.get(col);
237 let r_val = right.get(col);
238
239 match (l_val, r_val) {
240 (Some(a), Some(b)) if !nb_values_equal(a, b) => return Ok(false),
241 (None, None) => {} (None, Some(_)) | (Some(_), None) => return Ok(false),
243 _ => {}
244 }
245 }
246 Ok(true)
247 }
248
249 JoinCondition::Temporal {
250 left_time,
251 right_time,
252 within,
253 } => {
254 let l_ts = left.get(left_time).and_then(extract_timestamp_nb);
255 let r_ts = right.get(right_time).and_then(extract_timestamp_nb);
256
257 if let (Some(l), Some(r)) = (l_ts, r_ts) {
258 let diff_ms = (l - r).abs();
259 let threshold_ms = within.to_seconds() as f64 * 1000.0;
260 Ok(diff_ms <= threshold_ms)
261 } else {
262 Ok(false)
263 }
264 }
265
266 JoinCondition::Natural => {
267 for (k, l_val) in left {
269 if let Some(r_val) = right.get(k) {
270 if !nb_values_equal(l_val, r_val) {
271 return Ok(false);
272 }
273 }
274 }
275 Ok(true)
276 }
277 }
278 }
279
280 fn merge_rows(
282 left: &HashMap<String, ValueWord>,
283 right: &HashMap<String, ValueWord>,
284 right_prefix: &str,
285 ) -> HashMap<String, ValueWord> {
286 let mut merged = left.clone();
287
288 for (k, v) in right {
289 merged.insert(format!("{}.{}", right_prefix, k), v.clone());
290 }
291
292 merged
293 }
294
295 fn merge_with_nulls(
297 left: &HashMap<String, ValueWord>,
298 right_sample: &[HashMap<String, ValueWord>],
299 right_prefix: &str,
300 ) -> HashMap<String, ValueWord> {
301 let mut merged = left.clone();
302
303 if let Some(first_right) = right_sample.first() {
305 for k in first_right.keys() {
306 merged.insert(format!("{}.{}", right_prefix, k), ValueWord::none());
307 }
308 }
309
310 merged
311 }
312
313 fn merge_with_nulls_left(
315 left_sample: &[HashMap<String, ValueWord>],
316 right: &HashMap<String, ValueWord>,
317 right_prefix: &str,
318 ) -> HashMap<String, ValueWord> {
319 let mut merged = HashMap::new();
320
321 if let Some(first_left) = left_sample.first() {
323 for k in first_left.keys() {
324 merged.insert(k.clone(), ValueWord::none());
325 }
326 }
327
328 for (k, v) in right {
330 merged.insert(format!("{}.{}", right_prefix, k), v.clone());
331 }
332
333 merged
334 }
335}
336
337fn nb_values_equal(a: &ValueWord, b: &ValueWord) -> bool {
339 use shape_value::NanTag;
340 match (a.tag(), b.tag()) {
341 (NanTag::F64, NanTag::F64)
342 | (NanTag::I48, NanTag::I48)
343 | (NanTag::F64, NanTag::I48)
344 | (NanTag::I48, NanTag::F64) => {
345 if let (Some(an), Some(bn)) = (a.as_f64(), b.as_f64()) {
346 if an.is_nan() && bn.is_nan() {
347 true
348 } else {
349 (an - bn).abs() < f64::EPSILON
350 }
351 } else {
352 false
353 }
354 }
355 (NanTag::Heap, NanTag::Heap) => {
356 if let (Some(sa), Some(sb)) = (a.as_str(), b.as_str()) {
357 sa == sb
358 } else {
359 false
360 }
361 }
362 (NanTag::Bool, NanTag::Bool) => a.as_bool() == b.as_bool(),
363 (NanTag::None, NanTag::None) => true,
364 _ => {
365 if let (Some(ta), Some(tb)) = (a.as_time(), b.as_time()) {
367 ta == tb
368 } else {
369 false
370 }
371 }
372 }
373}
374
375fn extract_timestamp_nb(v: &ValueWord) -> Option<f64> {
377 if let Some(n) = v.as_f64() {
378 Some(n)
379 } else if let Some(t) = v.as_time() {
380 Some(t.timestamp_millis() as f64)
381 } else {
382 None
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ExecutionContext;
    use shape_ast::ast::JoinSource;
    use std::sync::Arc;

    /// Builds rows from lists of `(column, value)` pairs.
    fn make_rows(data: Vec<Vec<(&str, ValueWord)>>) -> Vec<HashMap<String, ValueWord>> {
        data.into_iter()
            .map(|row| row.into_iter().map(|(k, v)| (k.to_string(), v)).collect())
            .collect()
    }

    /// Wraps `s` as a string value.
    fn text(s: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(s.to_string()))
    }

    /// A `USING (id)` join of the given type against a named source.
    fn join_using_id(join_type: JoinType) -> JoinClause {
        JoinClause {
            join_type,
            right: JoinSource::Named("test".to_string()),
            condition: JoinCondition::Using(vec!["id".to_string()]),
        }
    }

    /// Numeric value of `row[column]`, if present and numeric.
    fn num(row: &HashMap<String, ValueWord>, column: &str) -> Option<f64> {
        row.get(column).and_then(|v| v.as_f64())
    }

    /// True when `row[column]` is present and holds `none`.
    fn is_null(row: &HashMap<String, ValueWord>, column: &str) -> bool {
        row.get(column).map(|v| v.is_none()).unwrap_or(false)
    }

    /// Left rows `id = 1, 2`; right rows `id = 2, 3` with `val = 20, 30`.
    fn overlapping_inputs() -> (
        Vec<HashMap<String, ValueWord>>,
        Vec<HashMap<String, ValueWord>>,
    ) {
        let left = make_rows(vec![
            vec![("id", ValueWord::from_f64(1.0))],
            vec![("id", ValueWord::from_f64(2.0))],
        ]);
        let right = make_rows(vec![
            vec![
                ("id", ValueWord::from_f64(2.0)),
                ("val", ValueWord::from_f64(20.0)),
            ],
            vec![
                ("id", ValueWord::from_f64(3.0)),
                ("val", ValueWord::from_f64(30.0)),
            ],
        ]);
        (left, right)
    }

    #[test]
    fn test_inner_join_using() {
        let mut ctx = ExecutionContext::new_empty();

        let left = make_rows(vec![
            vec![("id", ValueWord::from_f64(1.0)), ("name", text("A"))],
            vec![("id", ValueWord::from_f64(2.0)), ("name", text("B"))],
            vec![("id", ValueWord::from_f64(3.0)), ("name", text("C"))],
        ]);

        let right = make_rows(vec![
            vec![
                ("id", ValueWord::from_f64(1.0)),
                ("value", ValueWord::from_f64(100.0)),
            ],
            vec![
                ("id", ValueWord::from_f64(3.0)),
                ("value", ValueWord::from_f64(300.0)),
            ],
        ]);

        let join = join_using_id(JoinType::Inner);
        let result = JoinExecutor::execute(left, right, &join, &mut ctx).unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(num(&result[0], "id"), Some(1.0));
        assert_eq!(result[0].get("name").and_then(|v| v.as_str()), Some("A"));
        assert_eq!(num(&result[0], "right.value"), Some(100.0));
    }

    #[test]
    fn test_left_join() {
        let mut ctx = ExecutionContext::new_empty();

        let left = make_rows(vec![
            vec![("id", ValueWord::from_f64(1.0))],
            vec![("id", ValueWord::from_f64(2.0))],
        ]);

        let right = make_rows(vec![vec![
            ("id", ValueWord::from_f64(1.0)),
            ("val", ValueWord::from_f64(10.0)),
        ]]);

        let join = join_using_id(JoinType::Left);
        let result = JoinExecutor::execute(left, right, &join, &mut ctx).unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(num(&result[0], "right.val"), Some(10.0));
        // The unmatched left row is padded with right-hand nulls.
        assert!(is_null(&result[1], "right.val"));
    }

    #[test]
    fn test_right_join_pads_missing_left_columns() {
        let mut ctx = ExecutionContext::new_empty();
        let (left, right) = overlapping_inputs();

        let join = join_using_id(JoinType::Right);
        let result = JoinExecutor::execute(left, right, &join, &mut ctx).unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(num(&result[0], "id"), Some(2.0));
        assert_eq!(num(&result[0], "right.val"), Some(20.0));
        assert!(is_null(&result[1], "id"));
        assert_eq!(num(&result[1], "right.id"), Some(3.0));
    }

    #[test]
    fn test_full_join_keeps_unmatched_rows_from_both_sides() {
        let mut ctx = ExecutionContext::new_empty();
        let (left, right) = overlapping_inputs();

        let join = join_using_id(JoinType::Full);
        let result = JoinExecutor::execute(left, right, &join, &mut ctx).unwrap();

        assert_eq!(result.len(), 3);
        // Unmatched left row, padded with right-hand nulls.
        assert_eq!(num(&result[0], "id"), Some(1.0));
        assert!(is_null(&result[0], "right.val"));
        // Matched pair.
        assert_eq!(num(&result[1], "right.val"), Some(20.0));
        // Unmatched right row, padded with left-hand nulls.
        assert!(is_null(&result[2], "id"));
        assert_eq!(num(&result[2], "right.val"), Some(30.0));
    }

    #[test]
    fn test_cross_join() {
        let left = make_rows(vec![
            vec![("a", ValueWord::from_f64(1.0))],
            vec![("a", ValueWord::from_f64(2.0))],
        ]);

        let right = make_rows(vec![
            vec![("b", ValueWord::from_f64(10.0))],
            vec![("b", ValueWord::from_f64(20.0))],
        ]);

        let result = JoinExecutor::cross_join(left, right).unwrap();

        assert_eq!(result.len(), 4);
        assert!(result
            .iter()
            .all(|row| row.contains_key("a") && row.contains_key("right.b")));
    }
}