Skip to main content

rust_rule_engine/rete/
stream_join_node.rs

1use crate::streaming::event::StreamEvent;
2use std::collections::{HashMap, VecDeque};
3use std::time::Duration;
4
5// Type aliases to reduce type complexity warnings from clippy
6type KeyExtractor = Box<dyn Fn(&StreamEvent) -> Option<String> + Send + Sync>;
7type JoinCondition = Box<dyn Fn(&StreamEvent, &StreamEvent) -> bool + Send + Sync>;
8
/// Selects which matched and unmatched events a stream join reports.
#[derive(Debug)]
#[derive(Clone, PartialEq)]
pub enum JoinType {
    /// Report only pairs that found a partner on both sides.
    Inner,
    /// Report matched pairs plus left events that have no right partner.
    LeftOuter,
    /// Report matched pairs plus right events that have no left partner.
    RightOuter,
    /// Report matched pairs plus unmatched events from either side.
    FullOuter,
}
21
/// Windowing policy that bounds which buffered events may still be paired.
#[derive(Debug)]
#[derive(Clone, PartialEq)]
pub enum JoinStrategy {
    /// Pair events whose timestamps lie within `duration` of each other
    /// (the usual choice for streaming joins).
    TimeWindow { duration: Duration },
    /// Window bounded by a number of buffered events instead of by time.
    CountWindow { count: usize },
    /// Session-style window using `gap` as the maximum separation between events.
    SessionWindow { gap: Duration },
}
32
33/// Represents a matched pair of events from two streams
34#[derive(Debug, Clone)]
35pub struct JoinedEvent {
36    pub left: Option<StreamEvent>,
37    pub right: Option<StreamEvent>,
38    pub join_timestamp: i64,
39}
40
41/// Stream join node for RETE network
42/// Performs windowed joins between two streams based on join conditions
43pub struct StreamJoinNode {
44    /// Name of the left input stream
45    pub left_stream: String,
46    /// Name of the right input stream
47    pub right_stream: String,
48    /// Join type (inner, left outer, right outer, full outer)
49    pub join_type: JoinType,
50    /// Join strategy (time window, count window, session window)
51    pub join_strategy: JoinStrategy,
52    /// Join key extractor for left stream
53    pub left_key_extractor: KeyExtractor,
54    /// Join key extractor for right stream
55    pub right_key_extractor: KeyExtractor,
56    /// Additional join condition predicate
57    pub join_condition: JoinCondition,
58    /// Buffer for left stream events, partitioned by join key
59    left_buffer: HashMap<String, VecDeque<StreamEvent>>,
60    /// Buffer for right stream events, partitioned by join key
61    right_buffer: HashMap<String, VecDeque<StreamEvent>>,
62    /// Tracking for which left events have been matched (for outer joins)
63    left_matched: HashMap<String, bool>,
64    /// Tracking for which right events have been matched (for outer joins)
65    right_matched: HashMap<String, bool>,
66    /// Current watermark timestamp
67    watermark: i64,
68}
69
70impl StreamJoinNode {
71    /// Create a new stream join node
72    pub fn new(
73        left_stream: String,
74        right_stream: String,
75        join_type: JoinType,
76        join_strategy: JoinStrategy,
77        left_key_extractor: KeyExtractor,
78        right_key_extractor: KeyExtractor,
79        join_condition: JoinCondition,
80    ) -> Self {
81        Self {
82            left_stream,
83            right_stream,
84            join_type,
85            join_strategy,
86            left_key_extractor,
87            right_key_extractor,
88            join_condition,
89            left_buffer: HashMap::new(),
90            right_buffer: HashMap::new(),
91            left_matched: HashMap::new(),
92            right_matched: HashMap::new(),
93            watermark: 0,
94        }
95    }
96
97    /// Process a left stream event and produce joined events
98    pub fn process_left(&mut self, event: StreamEvent) -> Vec<JoinedEvent> {
99        let mut results = Vec::new();
100
101        // Extract join key
102        let key = match (self.left_key_extractor)(&event) {
103            Some(k) => k,
104            None => return results, // No key, skip
105        };
106
107        let event_id = Self::generate_event_id(&event);
108
109        // Add to buffer
110        self.left_buffer
111            .entry(key.clone())
112            .or_default()
113            .push_back(event.clone());
114
115        // Try to join with right stream events
116        if let Some(right_events) = self.right_buffer.get(&key) {
117            for right_event in right_events {
118                if self.is_within_window(&event, right_event)
119                    && (self.join_condition)(&event, right_event)
120                {
121                    results.push(JoinedEvent {
122                        left: Some(event.clone()),
123                        right: Some(right_event.clone()),
124                        join_timestamp: (event.metadata.timestamp as i64)
125                            .max(right_event.metadata.timestamp as i64),
126                    });
127
128                    // Mark as matched for outer join tracking
129                    self.left_matched.insert(event_id.clone(), true);
130                    self.right_matched
131                        .insert(Self::generate_event_id(right_event), true);
132                }
133            }
134        }
135
136        // For outer joins, emit unmatched left events
137        if (self.join_type == JoinType::LeftOuter || self.join_type == JoinType::FullOuter)
138            && !self.left_matched.contains_key(&event_id)
139        {
140            results.push(JoinedEvent {
141                left: Some(event.clone()),
142                right: None,
143                join_timestamp: event.metadata.timestamp as i64,
144            });
145        }
146
147        results
148    }
149
150    /// Process a right stream event and produce joined events
151    pub fn process_right(&mut self, event: StreamEvent) -> Vec<JoinedEvent> {
152        let mut results = Vec::new();
153
154        // Extract join key
155        let key = match (self.right_key_extractor)(&event) {
156            Some(k) => k,
157            None => return results, // No key, skip
158        };
159
160        let event_id = Self::generate_event_id(&event);
161
162        // Add to buffer
163        self.right_buffer
164            .entry(key.clone())
165            .or_default()
166            .push_back(event.clone());
167
168        // Try to join with left stream events
169        if let Some(left_events) = self.left_buffer.get(&key) {
170            for left_event in left_events {
171                if self.is_within_window(left_event, &event)
172                    && (self.join_condition)(left_event, &event)
173                {
174                    results.push(JoinedEvent {
175                        left: Some(left_event.clone()),
176                        right: Some(event.clone()),
177                        join_timestamp: (left_event.metadata.timestamp as i64)
178                            .max(event.metadata.timestamp as i64),
179                    });
180
181                    // Mark as matched for outer join tracking
182                    self.left_matched
183                        .insert(Self::generate_event_id(left_event), true);
184                    self.right_matched.insert(event_id.clone(), true);
185                }
186            }
187        }
188
189        // For outer joins, emit unmatched right events
190        if (self.join_type == JoinType::RightOuter || self.join_type == JoinType::FullOuter)
191            && !self.right_matched.contains_key(&event_id)
192        {
193            results.push(JoinedEvent {
194                left: None,
195                right: Some(event.clone()),
196                join_timestamp: event.metadata.timestamp as i64,
197            });
198        }
199
200        results
201    }
202
203    /// Update watermark and evict old events
204    pub fn update_watermark(&mut self, new_watermark: i64) -> Vec<JoinedEvent> {
205        let mut results = Vec::new();
206        self.watermark = new_watermark;
207
208        // First, for join types that should produce matched results (Inner, LeftOuter, RightOuter, FullOuter),
209        // emit all joined pairs currently in the buffers that satisfy the join window and condition.
210        // We do this before eviction so that pairs that are still within window are emitted.
211        if matches!(
212            self.join_type,
213            JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter
214        ) {
215            for (key, left_queue) in &self.left_buffer {
216                if let Some(right_queue) = self.right_buffer.get(key) {
217                    for left_event in left_queue {
218                        for right_event in right_queue {
219                            if self.is_within_window(left_event, right_event)
220                                && (self.join_condition)(left_event, right_event)
221                            {
222                                let left_id = Self::generate_event_id(left_event);
223                                let right_id = Self::generate_event_id(right_event);
224
225                                // Avoid emitting duplicates for events already marked as matched
226                                if !self.left_matched.contains_key(&left_id)
227                                    || !self.right_matched.contains_key(&right_id)
228                                {
229                                    results.push(JoinedEvent {
230                                        left: Some(left_event.clone()),
231                                        right: Some(right_event.clone()),
232                                        join_timestamp: (left_event.metadata.timestamp as i64)
233                                            .max(right_event.metadata.timestamp as i64),
234                                    });
235
236                                    // mark both as matched so outer-unmatched emission won't re-emit them
237                                    self.left_matched.insert(left_id.clone(), true);
238                                    self.right_matched.insert(right_id.clone(), true);
239                                }
240                            }
241                        }
242                    }
243                }
244            }
245        }
246
247        // Evict expired events from buffers
248        self.evict_expired_events();
249
250        // For outer joins, emit any remaining unmatched events that are now beyond the window
251        if self.join_type == JoinType::LeftOuter || self.join_type == JoinType::FullOuter {
252            results.extend(self.emit_unmatched_left());
253        }
254        if self.join_type == JoinType::RightOuter || self.join_type == JoinType::FullOuter {
255            results.extend(self.emit_unmatched_right());
256        }
257
258        results
259    }
260
261    /// Check if two events are within the join window
262    fn is_within_window(&self, left: &StreamEvent, right: &StreamEvent) -> bool {
263        match &self.join_strategy {
264            JoinStrategy::TimeWindow { duration } => {
265                // Compare timestamps and duration in seconds to match test conventions
266                let time_diff =
267                    ((left.metadata.timestamp as i64) - (right.metadata.timestamp as i64)).abs();
268                time_diff <= duration.as_secs() as i64
269            }
270            JoinStrategy::CountWindow { .. } => {
271                // For count windows, we handle this differently in buffer management
272                true
273            }
274            JoinStrategy::SessionWindow { gap } => {
275                // Session gap is compared in seconds
276                let time_diff =
277                    ((left.metadata.timestamp as i64) - (right.metadata.timestamp as i64)).abs();
278                time_diff <= gap.as_secs() as i64
279            }
280        }
281    }
282
283    /// Evict events that are outside the join window
284    fn evict_expired_events(&mut self) {
285        let watermark = self.watermark;
286        let window_size = self.get_window_duration();
287
288        // Evict from left buffer
289        for queue in self.left_buffer.values_mut() {
290            while let Some(event) = queue.front() {
291                if watermark - event.metadata.timestamp as i64 > window_size {
292                    if let Some(evicted) = queue.pop_front() {
293                        let id = Self::generate_event_id(&evicted);
294                        self.left_matched.remove(&id);
295                    }
296                } else {
297                    break;
298                }
299            }
300        }
301
302        // Evict from right buffer
303        for queue in self.right_buffer.values_mut() {
304            while let Some(event) = queue.front() {
305                if watermark - event.metadata.timestamp as i64 > window_size {
306                    if let Some(evicted) = queue.pop_front() {
307                        let id = Self::generate_event_id(&evicted);
308                        self.right_matched.remove(&id);
309                    }
310                } else {
311                    break;
312                }
313            }
314        }
315
316        // Clean up empty queues
317        self.left_buffer.retain(|_, queue| !queue.is_empty());
318        self.right_buffer.retain(|_, queue| !queue.is_empty());
319    }
320
321    /// Emit unmatched left events (for left/full outer joins)
322    fn emit_unmatched_left(&mut self) -> Vec<JoinedEvent> {
323        let mut results = Vec::new();
324        let watermark = self.watermark;
325        let window_size = self.get_window_duration();
326
327        for queue in self.left_buffer.values() {
328            for event in queue {
329                let id = Self::generate_event_id(event);
330                if !self.left_matched.contains_key(&id)
331                    && watermark - event.metadata.timestamp as i64 > window_size
332                {
333                    results.push(JoinedEvent {
334                        left: Some(event.clone()),
335                        right: None,
336                        join_timestamp: event.metadata.timestamp as i64,
337                    });
338                }
339            }
340        }
341
342        results
343    }
344
345    /// Emit unmatched right events (for right/full outer joins)
346    fn emit_unmatched_right(&mut self) -> Vec<JoinedEvent> {
347        let mut results = Vec::new();
348        let watermark = self.watermark;
349        let window_size = self.get_window_duration();
350
351        for queue in self.right_buffer.values() {
352            for event in queue {
353                let id = Self::generate_event_id(event);
354                if !self.right_matched.contains_key(&id)
355                    && watermark - event.metadata.timestamp as i64 > window_size
356                {
357                    results.push(JoinedEvent {
358                        left: None,
359                        right: Some(event.clone()),
360                        join_timestamp: event.metadata.timestamp as i64,
361                    });
362                }
363            }
364        }
365
366        results
367    }
368
369    /// Get window duration in milliseconds
370    fn get_window_duration(&self) -> i64 {
371        match &self.join_strategy {
372            // Return window duration in seconds (consistent with event timestamps used in tests)
373            JoinStrategy::TimeWindow { duration } => duration.as_secs() as i64,
374            JoinStrategy::SessionWindow { gap } => gap.as_secs() as i64,
375            JoinStrategy::CountWindow { .. } => i64::MAX, // Count windows don't time out
376        }
377    }
378
379    /// Generate a unique ID for an event
380    fn generate_event_id(event: &StreamEvent) -> String {
381        format!("{}_{}", event.id, event.metadata.timestamp as i64)
382    }
383
384    /// Get buffer statistics (for monitoring and debugging)
385    pub fn get_stats(&self) -> JoinNodeStats {
386        let left_count: usize = self.left_buffer.values().map(|q| q.len()).sum();
387        let right_count: usize = self.right_buffer.values().map(|q| q.len()).sum();
388
389        JoinNodeStats {
390            left_buffer_size: left_count,
391            right_buffer_size: right_count,
392            left_partitions: self.left_buffer.len(),
393            right_partitions: self.right_buffer.len(),
394            watermark: self.watermark,
395        }
396    }
397}
398
/// Point-in-time view of a join node's buffers, used for monitoring.
#[derive(Clone, Debug)]
pub struct JoinNodeStats {
    /// Total buffered left events, summed over all partitions.
    pub left_buffer_size: usize,
    /// Total buffered right events, summed over all partitions.
    pub right_buffer_size: usize,
    /// Number of distinct join keys buffered on the left.
    pub left_partitions: usize,
    /// Number of distinct join keys buffered on the right.
    pub right_partitions: usize,
    /// Watermark at the moment the snapshot was taken.
    pub watermark: i64,
}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an event from `stream_id` at `timestamp` whose `"key"` field holds `key`.
    ///
    /// The event id is derived from the timestamp alone, so two events built with
    /// the same timestamp share an id.
    fn create_test_event(stream_id: &str, timestamp: i64, key: &str) -> StreamEvent {
        use crate::streaming::event::EventMetadata;
        use crate::types::Value;

        StreamEvent {
            id: format!("test_{}", timestamp),
            event_type: "test".to_string(),
            // Store data under the field name "key" so the key extractor can find it
            data: vec![("key".to_string(), Value::String(key.to_string()))]
                .into_iter()
                .collect(),
            metadata: EventMetadata {
                timestamp: timestamp as u64,
                source: stream_id.to_string(),
                sequence: 0,
                tags: std::collections::HashMap::new(),
            },
        }
    }

    /// Key extractor shared by every test: reads the `"key"` string field.
    fn key_extractor() -> KeyExtractor {
        Box::new(|e| e.data.get("key").and_then(|v| v.as_string()))
    }

    /// Join node over a time window of `window_secs`, accepting every same-key pair.
    fn node(join_type: JoinType, window_secs: u64) -> StreamJoinNode {
        StreamJoinNode::new(
            "left".to_string(),
            "right".to_string(),
            join_type,
            JoinStrategy::TimeWindow {
                duration: Duration::from_secs(window_secs),
            },
            key_extractor(),
            key_extractor(),
            Box::new(|_, _| true),
        )
    }

    #[test]
    fn test_inner_join_basic() {
        let mut join_node = node(JoinType::Inner, 10);

        let left_event = create_test_event("left", 1000, "user1");
        let right_event = create_test_event("right", 1005, "user1");

        let results1 = join_node.process_left(left_event);
        assert_eq!(results1.len(), 0); // No right events yet

        let results2 = join_node.process_right(right_event);
        assert_eq!(results2.len(), 1); // Should join
        assert!(results2[0].left.is_some());
        assert!(results2[0].right.is_some());
    }

    #[test]
    fn test_time_window_filtering() {
        let mut join_node = node(JoinType::Inner, 5);

        let left_event = create_test_event("left", 1000, "user1");
        let right_event_close = create_test_event("right", 1003, "user1");
        let right_event_far = create_test_event("right", 8000, "user1");

        join_node.process_left(left_event);

        let results1 = join_node.process_right(right_event_close);
        assert_eq!(results1.len(), 1); // Within window

        let results2 = join_node.process_right(right_event_far);
        assert_eq!(results2.len(), 0); // Outside window
    }

    #[test]
    fn test_left_outer_join() {
        let mut join_node = node(JoinType::LeftOuter, 10);

        let left_event = create_test_event("left", 1000, "user1");

        let results = join_node.process_left(left_event);
        assert_eq!(results.len(), 1); // Emits unmatched left
        assert!(results[0].left.is_some());
        assert!(results[0].right.is_none());
    }

    #[test]
    fn test_right_outer_join() {
        let mut join_node = node(JoinType::RightOuter, 10);

        let right_event = create_test_event("right", 1000, "user1");

        let results = join_node.process_right(right_event);
        assert_eq!(results.len(), 1); // Emits unmatched right
        assert!(results[0].left.is_none());
        assert!(results[0].right.is_some());
    }

    #[test]
    fn test_partition_by_key() {
        let mut join_node = node(JoinType::Inner, 10);

        let left1 = create_test_event("left", 1000, "user1");
        let left2 = create_test_event("left", 1000, "user2");
        let right1 = create_test_event("right", 1005, "user1");

        join_node.process_left(left1);
        join_node.process_left(left2);

        let results = join_node.process_right(right1);
        assert_eq!(results.len(), 1); // Only joins with user1
        assert_eq!(
            results[0]
                .left
                .as_ref()
                .unwrap()
                .data
                .get("key")
                .unwrap()
                .as_string()
                .unwrap(),
            "user1"
        );
    }

    #[test]
    fn test_watermark_evicts_expired_events() {
        let mut join_node = node(JoinType::Inner, 10);
        join_node.process_left(create_test_event("left", 1000, "user1"));
        join_node.process_right(create_test_event("right", 1005, "user1"));

        // Both events are still inside the window relative to this watermark.
        assert!(join_node.update_watermark(1008).is_empty());
        let stats = join_node.get_stats();
        assert_eq!((stats.left_buffer_size, stats.right_buffer_size), (1, 1));

        // Far beyond the window: everything is evicted and empty partitions are dropped.
        assert!(join_node.update_watermark(2000).is_empty());
        let stats = join_node.get_stats();
        assert_eq!((stats.left_buffer_size, stats.right_buffer_size), (0, 0));
        assert_eq!((stats.left_partitions, stats.right_partitions), (0, 0));
        assert_eq!(stats.watermark, 2000);
    }
}