1use crate::streaming::event::StreamEvent;
2use std::collections::{HashMap, VecDeque};
3use std::time::Duration;
4
/// Which side(s) of a join are preserved when an event finds no partner.
///
/// The type is a plain tag, so it is `Copy` and can be used as a hash-map key
/// or compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Only matching left/right pairs are emitted.
    Inner,
    /// Matching pairs, plus left events without a partner (right side `None`).
    LeftOuter,
    /// Matching pairs, plus right events without a partner (left side `None`).
    RightOuter,
    /// Matching pairs, plus unmatched events from either side.
    FullOuter,
}
17
/// How candidate pairs are bounded when joining two streams.
///
/// Durations are compared against event timestamps at millisecond
/// granularity (`Duration::as_millis`), so timestamps are treated as
/// milliseconds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinStrategy {
    /// Two events may join when their timestamps differ by at most `duration`.
    TimeWindow { duration: Duration },
    /// Count-based window of `count` events. No timestamp comparison is
    /// applied to candidate pairs under this strategy.
    CountWindow { count: usize },
    /// Two events may join when their timestamps differ by at most `gap`.
    SessionWindow { gap: Duration },
}
28
/// One output row of a stream join.
///
/// A matched pair carries both sides; a null-padded outer-join row carries
/// only the preserved side and `None` for the other.
#[derive(Debug, Clone)]
pub struct JoinedEvent {
    /// Left-side event, or `None` for a right-only row.
    pub left: Option<StreamEvent>,
    /// Right-side event, or `None` for a left-only row.
    pub right: Option<StreamEvent>,
    /// For a matched pair, the later of the two event timestamps; for a
    /// one-sided row, the timestamp of the event that is present.
    pub join_timestamp: i64,
}
36
/// Stateful operator that joins two event streams on an extracted key.
///
/// Events from each side are buffered per key; an arriving event is probed
/// against the opposite side's buffer for the same key.
pub struct StreamJoinNode {
    /// Name of the left input stream (stored; not read by the logic in this file).
    pub left_stream: String,
    /// Name of the right input stream (stored; not read by the logic in this file).
    pub right_stream: String,
    /// Which unmatched events, if any, are emitted as one-sided rows.
    pub join_type: JoinType,
    /// Windowing rule that decides whether two events are close enough to join.
    pub join_strategy: JoinStrategy,
    /// Extracts the join key from a left event; `None` means the event is skipped.
    pub left_key_extractor: Box<dyn Fn(&StreamEvent) -> Option<String> + Send + Sync>,
    /// Extracts the join key from a right event; `None` means the event is skipped.
    pub right_key_extractor: Box<dyn Fn(&StreamEvent) -> Option<String> + Send + Sync>,
    /// Extra predicate a (left, right) pair must satisfy in addition to the
    /// key and window checks.
    pub join_condition: Box<dyn Fn(&StreamEvent, &StreamEvent) -> bool + Send + Sync>,
    /// Buffered left events, grouped by join key in arrival order.
    left_buffer: HashMap<String, VecDeque<StreamEvent>>,
    /// Buffered right events, grouped by join key in arrival order.
    right_buffer: HashMap<String, VecDeque<StreamEvent>>,
    /// Bookkeeping for left events, keyed by generated event id. The
    /// outer-join logic checks for key presence to decide whether a
    /// one-sided row is still owed for an event.
    left_matched: HashMap<String, bool>,
    /// Same bookkeeping as `left_matched`, for right events.
    right_matched: HashMap<String, bool>,
    /// Most recent watermark supplied via `update_watermark`; starts at 0.
    watermark: i64,
}
65
66impl StreamJoinNode {
67 pub fn new(
69 left_stream: String,
70 right_stream: String,
71 join_type: JoinType,
72 join_strategy: JoinStrategy,
73 left_key_extractor: Box<dyn Fn(&StreamEvent) -> Option<String> + Send + Sync>,
74 right_key_extractor: Box<dyn Fn(&StreamEvent) -> Option<String> + Send + Sync>,
75 join_condition: Box<dyn Fn(&StreamEvent, &StreamEvent) -> bool + Send + Sync>,
76 ) -> Self {
77 Self {
78 left_stream,
79 right_stream,
80 join_type,
81 join_strategy,
82 left_key_extractor,
83 right_key_extractor,
84 join_condition,
85 left_buffer: HashMap::new(),
86 right_buffer: HashMap::new(),
87 left_matched: HashMap::new(),
88 right_matched: HashMap::new(),
89 watermark: 0,
90 }
91 }
92
93 pub fn process_left(&mut self, event: StreamEvent) -> Vec<JoinedEvent> {
95 let mut results = Vec::new();
96
97 let key = match (self.left_key_extractor)(&event) {
99 Some(k) => k,
100 None => return results, };
102
103 let event_id = Self::generate_event_id(&event);
104
105 self.left_buffer
107 .entry(key.clone())
108 .or_insert_with(VecDeque::new)
109 .push_back(event.clone());
110
111 if let Some(right_events) = self.right_buffer.get(&key) {
113 for right_event in right_events {
114 if self.is_within_window(&event, right_event)
115 && (self.join_condition)(&event, right_event)
116 {
117 results.push(JoinedEvent {
118 left: Some(event.clone()),
119 right: Some(right_event.clone()),
120 join_timestamp: (event.metadata.timestamp as i64)
121 .max(right_event.metadata.timestamp as i64),
122 });
123
124 self.left_matched.insert(event_id.clone(), true);
126 self.right_matched
127 .insert(Self::generate_event_id(right_event), true);
128 }
129 }
130 }
131
132 if (self.join_type == JoinType::LeftOuter || self.join_type == JoinType::FullOuter)
134 && !self.left_matched.contains_key(&event_id)
135 {
136 results.push(JoinedEvent {
137 left: Some(event.clone()),
138 right: None,
139 join_timestamp: event.metadata.timestamp as i64,
140 });
141 }
142
143 results
144 }
145
146 pub fn process_right(&mut self, event: StreamEvent) -> Vec<JoinedEvent> {
148 let mut results = Vec::new();
149
150 let key = match (self.right_key_extractor)(&event) {
152 Some(k) => k,
153 None => return results, };
155
156 let event_id = Self::generate_event_id(&event);
157
158 self.right_buffer
160 .entry(key.clone())
161 .or_insert_with(VecDeque::new)
162 .push_back(event.clone());
163
164 if let Some(left_events) = self.left_buffer.get(&key) {
166 for left_event in left_events {
167 if self.is_within_window(left_event, &event)
168 && (self.join_condition)(left_event, &event)
169 {
170 results.push(JoinedEvent {
171 left: Some(left_event.clone()),
172 right: Some(event.clone()),
173 join_timestamp: (left_event.metadata.timestamp as i64)
174 .max(event.metadata.timestamp as i64),
175 });
176
177 self.left_matched
179 .insert(Self::generate_event_id(left_event), true);
180 self.right_matched.insert(event_id.clone(), true);
181 }
182 }
183 }
184
185 if (self.join_type == JoinType::RightOuter || self.join_type == JoinType::FullOuter)
187 && !self.right_matched.contains_key(&event_id)
188 {
189 results.push(JoinedEvent {
190 left: None,
191 right: Some(event.clone()),
192 join_timestamp: event.metadata.timestamp as i64,
193 });
194 }
195
196 results
197 }
198
199 pub fn update_watermark(&mut self, new_watermark: i64) -> Vec<JoinedEvent> {
201 let mut results = Vec::new();
202 self.watermark = new_watermark;
203
204 self.evict_expired_events();
206
207 if self.join_type == JoinType::LeftOuter || self.join_type == JoinType::FullOuter {
209 results.extend(self.emit_unmatched_left());
210 }
211 if self.join_type == JoinType::RightOuter || self.join_type == JoinType::FullOuter {
212 results.extend(self.emit_unmatched_right());
213 }
214
215 results
216 }
217
218 fn is_within_window(&self, left: &StreamEvent, right: &StreamEvent) -> bool {
220 match &self.join_strategy {
221 JoinStrategy::TimeWindow { duration } => {
222 let time_diff =
223 ((left.metadata.timestamp as i64) - (right.metadata.timestamp as i64)).abs();
224 time_diff <= duration.as_millis() as i64
225 }
226 JoinStrategy::CountWindow { .. } => {
227 true
229 }
230 JoinStrategy::SessionWindow { gap } => {
231 let time_diff =
232 ((left.metadata.timestamp as i64) - (right.metadata.timestamp as i64)).abs();
233 time_diff <= gap.as_millis() as i64
234 }
235 }
236 }
237
238 fn evict_expired_events(&mut self) {
240 let watermark = self.watermark;
241 let window_size = self.get_window_duration();
242
243 for queue in self.left_buffer.values_mut() {
245 while let Some(event) = queue.front() {
246 if watermark - event.metadata.timestamp as i64 > window_size {
247 if let Some(evicted) = queue.pop_front() {
248 let id = Self::generate_event_id(&evicted);
249 self.left_matched.remove(&id);
250 }
251 } else {
252 break;
253 }
254 }
255 }
256
257 for queue in self.right_buffer.values_mut() {
259 while let Some(event) = queue.front() {
260 if watermark - event.metadata.timestamp as i64 > window_size {
261 if let Some(evicted) = queue.pop_front() {
262 let id = Self::generate_event_id(&evicted);
263 self.right_matched.remove(&id);
264 }
265 } else {
266 break;
267 }
268 }
269 }
270
271 self.left_buffer.retain(|_, queue| !queue.is_empty());
273 self.right_buffer.retain(|_, queue| !queue.is_empty());
274 }
275
276 fn emit_unmatched_left(&mut self) -> Vec<JoinedEvent> {
278 let mut results = Vec::new();
279 let watermark = self.watermark;
280 let window_size = self.get_window_duration();
281
282 for queue in self.left_buffer.values() {
283 for event in queue {
284 let id = Self::generate_event_id(event);
285 if !self.left_matched.contains_key(&id)
286 && watermark - event.metadata.timestamp as i64 > window_size
287 {
288 results.push(JoinedEvent {
289 left: Some(event.clone()),
290 right: None,
291 join_timestamp: event.metadata.timestamp as i64,
292 });
293 }
294 }
295 }
296
297 results
298 }
299
300 fn emit_unmatched_right(&mut self) -> Vec<JoinedEvent> {
302 let mut results = Vec::new();
303 let watermark = self.watermark;
304 let window_size = self.get_window_duration();
305
306 for queue in self.right_buffer.values() {
307 for event in queue {
308 let id = Self::generate_event_id(event);
309 if !self.right_matched.contains_key(&id)
310 && watermark - event.metadata.timestamp as i64 > window_size
311 {
312 results.push(JoinedEvent {
313 left: None,
314 right: Some(event.clone()),
315 join_timestamp: event.metadata.timestamp as i64,
316 });
317 }
318 }
319 }
320
321 results
322 }
323
324 fn get_window_duration(&self) -> i64 {
326 match &self.join_strategy {
327 JoinStrategy::TimeWindow { duration } => duration.as_millis() as i64,
328 JoinStrategy::SessionWindow { gap } => gap.as_millis() as i64,
329 JoinStrategy::CountWindow { .. } => i64::MAX, }
331 }
332
333 fn generate_event_id(event: &StreamEvent) -> String {
335 format!("{}_{}", event.id, event.metadata.timestamp as i64)
336 }
337
338 pub fn get_stats(&self) -> JoinNodeStats {
340 let left_count: usize = self.left_buffer.values().map(|q| q.len()).sum();
341 let right_count: usize = self.right_buffer.values().map(|q| q.len()).sum();
342
343 JoinNodeStats {
344 left_buffer_size: left_count,
345 right_buffer_size: right_count,
346 left_partitions: self.left_buffer.len(),
347 right_partitions: self.right_buffer.len(),
348 watermark: self.watermark,
349 }
350 }
351}
352
/// Point-in-time occupancy figures for a join node.
///
/// All fields are plain numbers, so snapshots can be compared with `==` and
/// a zeroed baseline is available through `Default`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct JoinNodeStats {
    /// Total number of buffered left events across all keys.
    pub left_buffer_size: usize,
    /// Total number of buffered right events across all keys.
    pub right_buffer_size: usize,
    /// Number of distinct keys with buffered left events.
    pub left_partitions: usize,
    /// Number of distinct keys with buffered right events.
    pub right_partitions: usize,
    /// Watermark held by the node when the snapshot was taken.
    pub watermark: i64,
}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event whose `"key"` data field holds `key`.
    ///
    /// The value is stored under the field name `"key"` because that is the
    /// field every key extractor in these tests reads; storing it under the
    /// key's own value would make the extractors return `None` and the joins
    /// silently produce nothing.
    fn create_test_event(stream_id: &str, timestamp: i64, key: &str) -> StreamEvent {
        use crate::streaming::event::EventMetadata;
        use crate::types::Value;

        StreamEvent {
            id: format!("test_{}", timestamp),
            event_type: "test".to_string(),
            data: vec![("key".to_string(), Value::String(key.to_string()))]
                .into_iter()
                .collect(),
            metadata: EventMetadata {
                timestamp: timestamp as u64,
                source: stream_id.to_string(),
                sequence: 0,
                tags: std::collections::HashMap::new(),
            },
        }
    }

    /// Builds a node joining on the `"key"` field with a time window of
    /// `window_secs` seconds and an always-true join condition.
    fn keyed_node(join_type: JoinType, window_secs: u64) -> StreamJoinNode {
        StreamJoinNode::new(
            "left".to_string(),
            "right".to_string(),
            join_type,
            JoinStrategy::TimeWindow {
                duration: Duration::from_secs(window_secs),
            },
            Box::new(|e| e.data.get("key").and_then(|v| v.as_string())),
            Box::new(|e| e.data.get("key").and_then(|v| v.as_string())),
            Box::new(|_, _| true),
        )
    }

    #[test]
    fn test_inner_join_basic() {
        let mut join_node = keyed_node(JoinType::Inner, 10);

        let left_event = create_test_event("left", 1000, "user1");
        let right_event = create_test_event("right", 1005, "user1");

        // Nothing to join with yet.
        let results1 = join_node.process_left(left_event);
        assert_eq!(results1.len(), 0);

        // The right event pairs with the buffered left event.
        let results2 = join_node.process_right(right_event);
        assert_eq!(results2.len(), 1);
        assert!(results2[0].left.is_some());
        assert!(results2[0].right.is_some());
    }

    #[test]
    fn test_time_window_filtering() {
        let mut join_node = keyed_node(JoinType::Inner, 5);

        let left_event = create_test_event("left", 1000, "user1");
        let right_event_close = create_test_event("right", 1003, "user1");
        let right_event_far = create_test_event("right", 8000, "user1");

        join_node.process_left(left_event);

        // 3 ms apart: inside the 5 s window.
        let results1 = join_node.process_right(right_event_close);
        assert_eq!(results1.len(), 1);

        // 7000 ms apart: outside the 5 s window.
        let results2 = join_node.process_right(right_event_far);
        assert_eq!(results2.len(), 0);
    }

    #[test]
    fn test_left_outer_join() {
        let mut join_node = keyed_node(JoinType::LeftOuter, 10);

        let left_event = create_test_event("left", 1000, "user1");

        // With no right event buffered, the left event is emitted on its own.
        let results = join_node.process_left(left_event);
        assert_eq!(results.len(), 1);
        assert!(results[0].left.is_some());
        assert!(results[0].right.is_none());
    }

    #[test]
    fn test_partition_by_key() {
        let mut join_node = keyed_node(JoinType::Inner, 10);

        let left1 = create_test_event("left", 1000, "user1");
        let left2 = create_test_event("left", 1000, "user2");
        let right1 = create_test_event("right", 1005, "user1");

        join_node.process_left(left1);
        join_node.process_left(left2);

        // Only the left event sharing the right event's key may join.
        let results = join_node.process_right(right1);
        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0]
                .left
                .as_ref()
                .unwrap()
                .data
                .get("key")
                .unwrap()
                .as_string()
                .unwrap(),
            "user1"
        );
    }
}