1use crate::streaming::event::StreamEvent;
2use std::collections::{HashMap, VecDeque};
3use std::time::Duration;
4
5type KeyExtractor = Box<dyn Fn(&StreamEvent) -> Option<String> + Send + Sync>;
7type JoinCondition = Box<dyn Fn(&StreamEvent, &StreamEvent) -> bool + Send + Sync>;
8
/// Selects which events are reported when no partner is found on the other
/// stream.
///
/// `Eq` and `Hash` are derived so the type can be used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Only matched `(left, right)` pairs are emitted.
    Inner,
    /// Matched pairs, plus left events without a partner (`right` is `None`).
    LeftOuter,
    /// Matched pairs, plus right events without a partner (`left` is `None`).
    RightOuter,
    /// Matched pairs, plus unmatched events from both sides.
    FullOuter,
}
21
/// Windowing rule that decides which buffered events may be paired.
///
/// `Eq` and `Hash` are derived so strategies can be compared exactly and used
/// as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinStrategy {
    /// Pair events whose timestamps differ by at most `duration`.
    ///
    /// Only the whole-second part of `duration` is compared against the raw
    /// timestamp difference.
    TimeWindow { duration: Duration },
    /// Pair events regardless of their timestamps.
    ///
    /// `count` is presumably the number of events that make up the window —
    /// confirm against the join node implementation.
    CountWindow { count: usize },
    /// Pair events whose timestamps differ by at most `gap` (whole seconds).
    SessionWindow { gap: Duration },
}
32
33#[derive(Debug, Clone)]
35pub struct JoinedEvent {
36 pub left: Option<StreamEvent>,
37 pub right: Option<StreamEvent>,
38 pub join_timestamp: i64,
39}
40
41pub struct StreamJoinNode {
44 pub left_stream: String,
46 pub right_stream: String,
48 pub join_type: JoinType,
50 pub join_strategy: JoinStrategy,
52 pub left_key_extractor: KeyExtractor,
54 pub right_key_extractor: KeyExtractor,
56 pub join_condition: JoinCondition,
58 left_buffer: HashMap<String, VecDeque<StreamEvent>>,
60 right_buffer: HashMap<String, VecDeque<StreamEvent>>,
62 left_matched: HashMap<String, bool>,
64 right_matched: HashMap<String, bool>,
66 watermark: i64,
68}
69
70impl StreamJoinNode {
71 pub fn new(
73 left_stream: String,
74 right_stream: String,
75 join_type: JoinType,
76 join_strategy: JoinStrategy,
77 left_key_extractor: KeyExtractor,
78 right_key_extractor: KeyExtractor,
79 join_condition: JoinCondition,
80 ) -> Self {
81 Self {
82 left_stream,
83 right_stream,
84 join_type,
85 join_strategy,
86 left_key_extractor,
87 right_key_extractor,
88 join_condition,
89 left_buffer: HashMap::new(),
90 right_buffer: HashMap::new(),
91 left_matched: HashMap::new(),
92 right_matched: HashMap::new(),
93 watermark: 0,
94 }
95 }
96
97 pub fn process_left(&mut self, event: StreamEvent) -> Vec<JoinedEvent> {
99 let mut results = Vec::new();
100
101 let key = match (self.left_key_extractor)(&event) {
103 Some(k) => k,
104 None => return results, };
106
107 let event_id = Self::generate_event_id(&event);
108
109 self.left_buffer
111 .entry(key.clone())
112 .or_default()
113 .push_back(event.clone());
114
115 if let Some(right_events) = self.right_buffer.get(&key) {
117 for right_event in right_events {
118 if self.is_within_window(&event, right_event)
119 && (self.join_condition)(&event, right_event)
120 {
121 results.push(JoinedEvent {
122 left: Some(event.clone()),
123 right: Some(right_event.clone()),
124 join_timestamp: (event.metadata.timestamp as i64)
125 .max(right_event.metadata.timestamp as i64),
126 });
127
128 self.left_matched.insert(event_id.clone(), true);
130 self.right_matched
131 .insert(Self::generate_event_id(right_event), true);
132 }
133 }
134 }
135
136 if (self.join_type == JoinType::LeftOuter || self.join_type == JoinType::FullOuter)
138 && !self.left_matched.contains_key(&event_id)
139 {
140 results.push(JoinedEvent {
141 left: Some(event.clone()),
142 right: None,
143 join_timestamp: event.metadata.timestamp as i64,
144 });
145 }
146
147 results
148 }
149
150 pub fn process_right(&mut self, event: StreamEvent) -> Vec<JoinedEvent> {
152 let mut results = Vec::new();
153
154 let key = match (self.right_key_extractor)(&event) {
156 Some(k) => k,
157 None => return results, };
159
160 let event_id = Self::generate_event_id(&event);
161
162 self.right_buffer
164 .entry(key.clone())
165 .or_default()
166 .push_back(event.clone());
167
168 if let Some(left_events) = self.left_buffer.get(&key) {
170 for left_event in left_events {
171 if self.is_within_window(left_event, &event)
172 && (self.join_condition)(left_event, &event)
173 {
174 results.push(JoinedEvent {
175 left: Some(left_event.clone()),
176 right: Some(event.clone()),
177 join_timestamp: (left_event.metadata.timestamp as i64)
178 .max(event.metadata.timestamp as i64),
179 });
180
181 self.left_matched
183 .insert(Self::generate_event_id(left_event), true);
184 self.right_matched.insert(event_id.clone(), true);
185 }
186 }
187 }
188
189 if (self.join_type == JoinType::RightOuter || self.join_type == JoinType::FullOuter)
191 && !self.right_matched.contains_key(&event_id)
192 {
193 results.push(JoinedEvent {
194 left: None,
195 right: Some(event.clone()),
196 join_timestamp: event.metadata.timestamp as i64,
197 });
198 }
199
200 results
201 }
202
203 pub fn update_watermark(&mut self, new_watermark: i64) -> Vec<JoinedEvent> {
205 let mut results = Vec::new();
206 self.watermark = new_watermark;
207
208 if matches!(
212 self.join_type,
213 JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter
214 ) {
215 for (key, left_queue) in &self.left_buffer {
216 if let Some(right_queue) = self.right_buffer.get(key) {
217 for left_event in left_queue {
218 for right_event in right_queue {
219 if self.is_within_window(left_event, right_event)
220 && (self.join_condition)(left_event, right_event)
221 {
222 let left_id = Self::generate_event_id(left_event);
223 let right_id = Self::generate_event_id(right_event);
224
225 if !self.left_matched.contains_key(&left_id)
227 || !self.right_matched.contains_key(&right_id)
228 {
229 results.push(JoinedEvent {
230 left: Some(left_event.clone()),
231 right: Some(right_event.clone()),
232 join_timestamp: (left_event.metadata.timestamp as i64)
233 .max(right_event.metadata.timestamp as i64),
234 });
235
236 self.left_matched.insert(left_id.clone(), true);
238 self.right_matched.insert(right_id.clone(), true);
239 }
240 }
241 }
242 }
243 }
244 }
245 }
246
247 self.evict_expired_events();
249
250 if self.join_type == JoinType::LeftOuter || self.join_type == JoinType::FullOuter {
252 results.extend(self.emit_unmatched_left());
253 }
254 if self.join_type == JoinType::RightOuter || self.join_type == JoinType::FullOuter {
255 results.extend(self.emit_unmatched_right());
256 }
257
258 results
259 }
260
261 fn is_within_window(&self, left: &StreamEvent, right: &StreamEvent) -> bool {
263 match &self.join_strategy {
264 JoinStrategy::TimeWindow { duration } => {
265 let time_diff =
267 ((left.metadata.timestamp as i64) - (right.metadata.timestamp as i64)).abs();
268 time_diff <= duration.as_secs() as i64
269 }
270 JoinStrategy::CountWindow { .. } => {
271 true
273 }
274 JoinStrategy::SessionWindow { gap } => {
275 let time_diff =
277 ((left.metadata.timestamp as i64) - (right.metadata.timestamp as i64)).abs();
278 time_diff <= gap.as_secs() as i64
279 }
280 }
281 }
282
283 fn evict_expired_events(&mut self) {
285 let watermark = self.watermark;
286 let window_size = self.get_window_duration();
287
288 for queue in self.left_buffer.values_mut() {
290 while let Some(event) = queue.front() {
291 if watermark - event.metadata.timestamp as i64 > window_size {
292 if let Some(evicted) = queue.pop_front() {
293 let id = Self::generate_event_id(&evicted);
294 self.left_matched.remove(&id);
295 }
296 } else {
297 break;
298 }
299 }
300 }
301
302 for queue in self.right_buffer.values_mut() {
304 while let Some(event) = queue.front() {
305 if watermark - event.metadata.timestamp as i64 > window_size {
306 if let Some(evicted) = queue.pop_front() {
307 let id = Self::generate_event_id(&evicted);
308 self.right_matched.remove(&id);
309 }
310 } else {
311 break;
312 }
313 }
314 }
315
316 self.left_buffer.retain(|_, queue| !queue.is_empty());
318 self.right_buffer.retain(|_, queue| !queue.is_empty());
319 }
320
321 fn emit_unmatched_left(&mut self) -> Vec<JoinedEvent> {
323 let mut results = Vec::new();
324 let watermark = self.watermark;
325 let window_size = self.get_window_duration();
326
327 for queue in self.left_buffer.values() {
328 for event in queue {
329 let id = Self::generate_event_id(event);
330 if !self.left_matched.contains_key(&id)
331 && watermark - event.metadata.timestamp as i64 > window_size
332 {
333 results.push(JoinedEvent {
334 left: Some(event.clone()),
335 right: None,
336 join_timestamp: event.metadata.timestamp as i64,
337 });
338 }
339 }
340 }
341
342 results
343 }
344
345 fn emit_unmatched_right(&mut self) -> Vec<JoinedEvent> {
347 let mut results = Vec::new();
348 let watermark = self.watermark;
349 let window_size = self.get_window_duration();
350
351 for queue in self.right_buffer.values() {
352 for event in queue {
353 let id = Self::generate_event_id(event);
354 if !self.right_matched.contains_key(&id)
355 && watermark - event.metadata.timestamp as i64 > window_size
356 {
357 results.push(JoinedEvent {
358 left: None,
359 right: Some(event.clone()),
360 join_timestamp: event.metadata.timestamp as i64,
361 });
362 }
363 }
364 }
365
366 results
367 }
368
369 fn get_window_duration(&self) -> i64 {
371 match &self.join_strategy {
372 JoinStrategy::TimeWindow { duration } => duration.as_secs() as i64,
374 JoinStrategy::SessionWindow { gap } => gap.as_secs() as i64,
375 JoinStrategy::CountWindow { .. } => i64::MAX, }
377 }
378
379 fn generate_event_id(event: &StreamEvent) -> String {
381 format!("{}_{}", event.id, event.metadata.timestamp as i64)
382 }
383
384 pub fn get_stats(&self) -> JoinNodeStats {
386 let left_count: usize = self.left_buffer.values().map(|q| q.len()).sum();
387 let right_count: usize = self.right_buffer.values().map(|q| q.len()).sum();
388
389 JoinNodeStats {
390 left_buffer_size: left_count,
391 right_buffer_size: right_count,
392 left_partitions: self.left_buffer.len(),
393 right_partitions: self.right_buffer.len(),
394 watermark: self.watermark,
395 }
396 }
397}
398
/// Point-in-time snapshot of a join node's internal buffers.
///
/// `PartialEq` / `Eq` are derived so snapshots can be compared directly,
/// e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JoinNodeStats {
    /// Total number of buffered left events across all partitions.
    pub left_buffer_size: usize,
    /// Total number of buffered right events across all partitions.
    pub right_buffer_size: usize,
    /// Number of distinct keys currently present in the left buffer.
    pub left_partitions: usize,
    /// Number of distinct keys currently present in the right buffer.
    pub right_partitions: usize,
    /// Watermark value at the time of the snapshot.
    pub watermark: i64,
}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event whose `data` holds a single `"key"` entry and whose ID
    /// is derived from `timestamp`.
    fn create_test_event(stream_id: &str, timestamp: i64, key: &str) -> StreamEvent {
        use crate::streaming::event::EventMetadata;
        use crate::types::Value;

        StreamEvent {
            id: format!("test_{}", timestamp),
            event_type: "test".to_string(),
            data: vec![("key".to_string(), Value::String(key.to_string()))]
                .into_iter()
                .collect(),
            metadata: EventMetadata {
                timestamp: timestamp as u64,
                source: stream_id.to_string(),
                sequence: 0,
                tags: std::collections::HashMap::new(),
            },
        }
    }

    /// Builds a join node that partitions both streams by the `"key"` field,
    /// uses a time window of `window_secs`, and accepts every candidate pair.
    fn time_window_node(join_type: JoinType, window_secs: u64) -> StreamJoinNode {
        StreamJoinNode::new(
            "left".to_string(),
            "right".to_string(),
            join_type,
            JoinStrategy::TimeWindow {
                duration: Duration::from_secs(window_secs),
            },
            Box::new(|e| e.data.get("key").and_then(|v| v.as_string())),
            Box::new(|e| e.data.get("key").and_then(|v| v.as_string())),
            Box::new(|_, _| true),
        )
    }

    #[test]
    fn test_inner_join_basic() {
        let mut join_node = time_window_node(JoinType::Inner, 10);

        let left_event = create_test_event("left", 1000, "user1");
        let right_event = create_test_event("right", 1005, "user1");

        // Nothing is buffered on the right yet, so an inner join emits nothing.
        let results1 = join_node.process_left(left_event);
        assert_eq!(results1.len(), 0);

        // The right event pairs with the buffered left event.
        let results2 = join_node.process_right(right_event);
        assert_eq!(results2.len(), 1);
        assert!(results2[0].left.is_some());
        assert!(results2[0].right.is_some());
    }

    #[test]
    fn test_time_window_filtering() {
        let mut join_node = time_window_node(JoinType::Inner, 5);

        let left_event = create_test_event("left", 1000, "user1");
        let right_event_close = create_test_event("right", 1003, "user1");
        let right_event_far = create_test_event("right", 8000, "user1");

        join_node.process_left(left_event);

        // Distance 3 is inside the 5-second window.
        let results1 = join_node.process_right(right_event_close);
        assert_eq!(results1.len(), 1);

        // Distance 7000 is outside the window.
        let results2 = join_node.process_right(right_event_far);
        assert_eq!(results2.len(), 0);
    }

    #[test]
    fn test_left_outer_join() {
        let mut join_node = time_window_node(JoinType::LeftOuter, 10);

        let left_event = create_test_event("left", 1000, "user1");

        // An unmatched left event is reported immediately with no right side.
        let results = join_node.process_left(left_event);
        assert_eq!(results.len(), 1);
        assert!(results[0].left.is_some());
        assert!(results[0].right.is_none());
    }

    #[test]
    fn test_right_outer_join() {
        let mut join_node = time_window_node(JoinType::RightOuter, 10);

        let right_event = create_test_event("right", 1000, "user1");

        // An unmatched right event is reported immediately with no left side.
        let results = join_node.process_right(right_event);
        assert_eq!(results.len(), 1);
        assert!(results[0].left.is_none());
        assert!(results[0].right.is_some());
    }

    #[test]
    fn test_partition_by_key() {
        let mut join_node = time_window_node(JoinType::Inner, 10);

        let left1 = create_test_event("left", 1000, "user1");
        let left2 = create_test_event("left", 1000, "user2");
        let right1 = create_test_event("right", 1005, "user1");

        join_node.process_left(left1);
        join_node.process_left(left2);

        // Only the left event that shares the right event's key may join.
        let results = join_node.process_right(right1);
        assert_eq!(results.len(), 1);

        let joined_left = results[0].left.as_ref().unwrap();
        assert_eq!(
            joined_left.data.get("key").unwrap().as_string().unwrap(),
            "user1"
        );
    }
}