rust_rule_engine/streaming/
join_manager.rs

1use crate::rete::stream_join_node::{JoinedEvent, StreamJoinNode};
2use crate::streaming::event::StreamEvent;
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6/// Manages multiple stream joins and coordinates event routing
7pub struct StreamJoinManager {
8    /// All registered join nodes, indexed by join ID
9    joins: HashMap<String, Arc<Mutex<StreamJoinNode>>>,
10    /// Maps stream names to the join nodes that consume them
11    stream_to_joins: HashMap<String, Vec<String>>,
12    /// Result handlers for each join
13    result_handlers: HashMap<String, Box<dyn Fn(JoinedEvent) + Send + Sync>>,
14}
15
16impl StreamJoinManager {
17    /// Create a new stream join manager
18    pub fn new() -> Self {
19        Self {
20            joins: HashMap::new(),
21            stream_to_joins: HashMap::new(),
22            result_handlers: HashMap::new(),
23        }
24    }
25
26    /// Register a new stream join
27    pub fn register_join(
28        &mut self,
29        join_id: String,
30        join_node: StreamJoinNode,
31        result_handler: Box<dyn Fn(JoinedEvent) + Send + Sync>,
32    ) {
33        let left_stream = join_node.left_stream.clone();
34        let right_stream = join_node.right_stream.clone();
35
36        // Index the join by streams
37        self.stream_to_joins
38            .entry(left_stream)
39            .or_default()
40            .push(join_id.clone());
41
42        self.stream_to_joins
43            .entry(right_stream)
44            .or_default()
45            .push(join_id.clone());
46
47        // Store join node and handler
48        self.joins
49            .insert(join_id.clone(), Arc::new(Mutex::new(join_node)));
50        self.result_handlers.insert(join_id, result_handler);
51    }
52
53    /// Remove a stream join
54    pub fn unregister_join(&mut self, join_id: &str) {
55        if let Some(join) = self.joins.get(join_id) {
56            let join_lock = join.lock().unwrap();
57            let left_stream = join_lock.left_stream.clone();
58            let right_stream = join_lock.right_stream.clone();
59
60            // Remove from stream indices
61            if let Some(joins) = self.stream_to_joins.get_mut(&left_stream) {
62                joins.retain(|id| id != join_id);
63            }
64            if let Some(joins) = self.stream_to_joins.get_mut(&right_stream) {
65                joins.retain(|id| id != join_id);
66            }
67        }
68
69        self.joins.remove(join_id);
70        self.result_handlers.remove(join_id);
71    }
72
73    /// Process an incoming stream event
74    /// Routes the event to all relevant join nodes
75    pub fn process_event(&self, event: StreamEvent) {
76        let stream_id = event.metadata.source.clone();
77
78        // Find all joins that consume this stream
79        if let Some(join_ids) = self.stream_to_joins.get(&stream_id) {
80            for join_id in join_ids {
81                if let Some(join) = self.joins.get(join_id) {
82                    let mut join_lock = join.lock().unwrap();
83
84                    // Determine if this is a left or right stream for this join
85                    let results = if join_lock.left_stream == stream_id {
86                        join_lock.process_left(event.clone())
87                    } else {
88                        join_lock.process_right(event.clone())
89                    };
90
91                    // Process results
92                    if let Some(handler) = self.result_handlers.get(join_id) {
93                        for joined in results {
94                            handler(joined);
95                        }
96                    }
97                }
98            }
99        }
100    }
101
102    /// Update watermark for a specific stream
103    /// This triggers eviction of old events and emission of outer join results
104    pub fn update_watermark(&self, stream_id: &str, watermark: i64) {
105        if let Some(join_ids) = self.stream_to_joins.get(stream_id) {
106            for join_id in join_ids {
107                if let Some(join) = self.joins.get(join_id) {
108                    let mut join_lock = join.lock().unwrap();
109                    let results = join_lock.update_watermark(watermark);
110
111                    // Process results from watermark update (outer join emissions)
112                    if let Some(handler) = self.result_handlers.get(join_id) {
113                        for joined in results {
114                            handler(joined);
115                        }
116                    }
117                }
118            }
119        }
120    }
121
122    /// Get statistics for all joins
123    pub fn get_all_stats(&self) -> HashMap<String, crate::rete::stream_join_node::JoinNodeStats> {
124        let mut stats = HashMap::new();
125        for (join_id, join) in &self.joins {
126            let join_lock = join.lock().unwrap();
127            stats.insert(join_id.clone(), join_lock.get_stats());
128        }
129        stats
130    }
131
132    /// Get statistics for a specific join
133    pub fn get_join_stats(
134        &self,
135        join_id: &str,
136    ) -> Option<crate::rete::stream_join_node::JoinNodeStats> {
137        self.joins.get(join_id).map(|join| {
138            let join_lock = join.lock().unwrap();
139            join_lock.get_stats()
140        })
141    }
142
143    /// Clear all joins (for testing or reset)
144    pub fn clear(&mut self) {
145        self.joins.clear();
146        self.stream_to_joins.clear();
147        self.result_handlers.clear();
148    }
149}
150
151impl Default for StreamJoinManager {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::stream_join_node::{JoinStrategy, JoinType};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Build an event on `stream_id` whose only data field is `user_id`.
    fn create_test_event(stream_id: &str, timestamp: i64, user_id: &str) -> StreamEvent {
        use crate::streaming::event::EventMetadata;
        use crate::types::Value;

        let user_field = ("user_id".to_string(), Value::String(user_id.to_string()));
        let metadata = EventMetadata {
            timestamp: timestamp as u64,
            source: stream_id.to_string(),
            sequence: 0,
            tags: HashMap::new(),
        };

        StreamEvent {
            id: format!("test_{}_{}", stream_id, timestamp),
            event_type: "test".to_string(),
            data: std::iter::once(user_field).collect(),
            metadata,
        }
    }

    /// Build a time-window join between `left` and `right` that keys both
    /// sides on the `user_id` field and accepts every candidate pair.
    fn user_id_join(
        left: &str,
        right: &str,
        join_type: JoinType,
        window_secs: u64,
    ) -> StreamJoinNode {
        StreamJoinNode::new(
            left.to_string(),
            right.to_string(),
            join_type,
            JoinStrategy::TimeWindow {
                duration: Duration::from_secs(window_secs),
            },
            Box::new(|e| e.data.get("user_id").and_then(|v| v.as_string())),
            Box::new(|e| e.data.get("user_id").and_then(|v| v.as_string())),
            Box::new(|_, _| true),
        )
    }

    /// Return a shared counter plus a result handler that bumps it once per
    /// joined event it receives.
    fn counting_handler() -> (Arc<AtomicUsize>, Box<dyn Fn(JoinedEvent) + Send + Sync>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&counter);
        let handler: Box<dyn Fn(JoinedEvent) + Send + Sync> = Box::new(move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        (counter, handler)
    }

    #[test]
    fn test_register_and_route_events() {
        let (matches, on_match) = counting_handler();
        let mut manager = StreamJoinManager::new();
        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::Inner, 10),
            on_match,
        );

        // One event per side, same user, within the window.
        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));

        // Exactly one joined pair is expected.
        assert_eq!(matches.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiple_joins_same_stream() {
        let (matches_join1, on_match_join1) = counting_handler();
        let (matches_join2, on_match_join2) = counting_handler();
        let mut manager = StreamJoinManager::new();

        // join1 pairs "left" with "right"; join2 pairs "left" with "other".
        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::Inner, 10),
            on_match_join1,
        );
        manager.register_join(
            "join2".to_string(),
            user_id_join("left", "other", JoinType::Inner, 10),
            on_match_join2,
        );

        // The "left" event fans out to both joins ...
        manager.process_event(create_test_event("left", 1000, "user1"));
        // ... while "right" is consumed by join1 only ...
        manager.process_event(create_test_event("right", 1005, "user1"));
        // ... and "other" by join2 only.
        manager.process_event(create_test_event("other", 1005, "user1"));

        // Each join produced exactly one result.
        assert_eq!(matches_join1.load(Ordering::SeqCst), 1);
        assert_eq!(matches_join2.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_unregister_join() {
        let (matches, on_match) = counting_handler();
        let mut manager = StreamJoinManager::new();
        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::Inner, 10),
            on_match,
        );

        manager.unregister_join("join1");

        // With the join gone, matching events must not reach the handler.
        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));

        assert_eq!(matches.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_watermark_update() {
        let (emissions, on_emit) = counting_handler();
        let mut manager = StreamJoinManager::new();

        // A left outer join is used so that unmatched left events are emitted.
        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::LeftOuter, 5),
            on_emit,
        );

        manager.process_event(create_test_event("left", 1000, "user1"));

        // The unmatched left event is expected to be emitted right away.
        assert_eq!(emissions.load(Ordering::SeqCst), 1);

        // Advancing the watermark must not emit that event a second time.
        manager.update_watermark("left", 10000);

        assert_eq!(emissions.load(Ordering::SeqCst), 1);
    }
}