rust_rule_engine/streaming/
join_manager.rs1use crate::rete::stream_join_node::{JoinedEvent, StreamJoinNode};
2use crate::streaming::event::StreamEvent;
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6pub struct StreamJoinManager {
8 joins: HashMap<String, Arc<Mutex<StreamJoinNode>>>,
10 stream_to_joins: HashMap<String, Vec<String>>,
12 result_handlers: HashMap<String, Box<dyn Fn(JoinedEvent) + Send + Sync>>,
14}
15
16impl StreamJoinManager {
17 pub fn new() -> Self {
19 Self {
20 joins: HashMap::new(),
21 stream_to_joins: HashMap::new(),
22 result_handlers: HashMap::new(),
23 }
24 }
25
26 pub fn register_join(
28 &mut self,
29 join_id: String,
30 join_node: StreamJoinNode,
31 result_handler: Box<dyn Fn(JoinedEvent) + Send + Sync>,
32 ) {
33 let left_stream = join_node.left_stream.clone();
34 let right_stream = join_node.right_stream.clone();
35
36 self.stream_to_joins
38 .entry(left_stream)
39 .or_insert_with(Vec::new)
40 .push(join_id.clone());
41
42 self.stream_to_joins
43 .entry(right_stream)
44 .or_insert_with(Vec::new)
45 .push(join_id.clone());
46
47 self.joins
49 .insert(join_id.clone(), Arc::new(Mutex::new(join_node)));
50 self.result_handlers.insert(join_id, result_handler);
51 }
52
53 pub fn unregister_join(&mut self, join_id: &str) {
55 if let Some(join) = self.joins.get(join_id) {
56 let join_lock = join.lock().unwrap();
57 let left_stream = join_lock.left_stream.clone();
58 let right_stream = join_lock.right_stream.clone();
59
60 if let Some(joins) = self.stream_to_joins.get_mut(&left_stream) {
62 joins.retain(|id| id != join_id);
63 }
64 if let Some(joins) = self.stream_to_joins.get_mut(&right_stream) {
65 joins.retain(|id| id != join_id);
66 }
67 }
68
69 self.joins.remove(join_id);
70 self.result_handlers.remove(join_id);
71 }
72
73 pub fn process_event(&self, event: StreamEvent) {
76 let stream_id = event.metadata.source.clone();
77
78 if let Some(join_ids) = self.stream_to_joins.get(&stream_id) {
80 for join_id in join_ids {
81 if let Some(join) = self.joins.get(join_id) {
82 let mut join_lock = join.lock().unwrap();
83
84 let results = if join_lock.left_stream == stream_id {
86 join_lock.process_left(event.clone())
87 } else {
88 join_lock.process_right(event.clone())
89 };
90
91 if let Some(handler) = self.result_handlers.get(join_id) {
93 for joined in results {
94 handler(joined);
95 }
96 }
97 }
98 }
99 }
100 }
101
102 pub fn update_watermark(&self, stream_id: &str, watermark: i64) {
105 if let Some(join_ids) = self.stream_to_joins.get(stream_id) {
106 for join_id in join_ids {
107 if let Some(join) = self.joins.get(join_id) {
108 let mut join_lock = join.lock().unwrap();
109 let results = join_lock.update_watermark(watermark);
110
111 if let Some(handler) = self.result_handlers.get(join_id) {
113 for joined in results {
114 handler(joined);
115 }
116 }
117 }
118 }
119 }
120 }
121
122 pub fn get_all_stats(&self) -> HashMap<String, crate::rete::stream_join_node::JoinNodeStats> {
124 let mut stats = HashMap::new();
125 for (join_id, join) in &self.joins {
126 let join_lock = join.lock().unwrap();
127 stats.insert(join_id.clone(), join_lock.get_stats());
128 }
129 stats
130 }
131
132 pub fn get_join_stats(
134 &self,
135 join_id: &str,
136 ) -> Option<crate::rete::stream_join_node::JoinNodeStats> {
137 self.joins.get(join_id).map(|join| {
138 let join_lock = join.lock().unwrap();
139 join_lock.get_stats()
140 })
141 }
142
143 pub fn clear(&mut self) {
145 self.joins.clear();
146 self.stream_to_joins.clear();
147 self.result_handlers.clear();
148 }
149}
150
151impl Default for StreamJoinManager {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::stream_join_node::{JoinStrategy, JoinType};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Builds a `StreamEvent` on `stream_id` carrying a single `user_id` field.
    fn create_test_event(stream_id: &str, timestamp: i64, user_id: &str) -> StreamEvent {
        use crate::streaming::event::EventMetadata;
        use crate::types::Value;

        let metadata = EventMetadata {
            timestamp: timestamp as u64,
            source: stream_id.to_string(),
            sequence: 0,
            tags: HashMap::new(),
        };

        StreamEvent {
            id: format!("test_{}_{}", stream_id, timestamp),
            event_type: "test".to_string(),
            data: std::iter::once(("user_id".to_string(), Value::String(user_id.to_string())))
                .collect(),
            metadata,
        }
    }

    /// Builds a time-window join keyed on `user_id` on both sides, with an
    /// always-true join condition.
    fn user_id_join(
        left: &str,
        right: &str,
        join_type: JoinType,
        window_secs: u64,
    ) -> StreamJoinNode {
        StreamJoinNode::new(
            left.to_string(),
            right.to_string(),
            join_type,
            JoinStrategy::TimeWindow {
                duration: Duration::from_secs(window_secs),
            },
            Box::new(|e| e.data.get("user_id").and_then(|v| v.as_string())),
            Box::new(|e| e.data.get("user_id").and_then(|v| v.as_string())),
            Box::new(|_, _| true),
        )
    }

    /// Returns a result handler that increments `counter` once per joined event.
    fn counting_handler(counter: &Arc<AtomicUsize>) -> Box<dyn Fn(JoinedEvent) + Send + Sync> {
        let counter = Arc::clone(counter);
        Box::new(move |_: JoinedEvent| {
            counter.fetch_add(1, Ordering::SeqCst);
        })
    }

    #[test]
    fn test_register_and_route_events() {
        let mut manager = StreamJoinManager::new();
        let result_count = Arc::new(AtomicUsize::new(0));

        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::Inner, 10),
            counting_handler(&result_count),
        );

        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));

        assert_eq!(result_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiple_joins_same_stream() {
        let mut manager = StreamJoinManager::new();
        let result_count1 = Arc::new(AtomicUsize::new(0));
        let result_count2 = Arc::new(AtomicUsize::new(0));

        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::Inner, 10),
            counting_handler(&result_count1),
        );
        manager.register_join(
            "join2".to_string(),
            user_id_join("left", "other", JoinType::Inner, 10),
            counting_handler(&result_count2),
        );

        // "left" feeds both joins; "right" and "other" each complete one of them.
        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));
        manager.process_event(create_test_event("other", 1005, "user1"));

        assert_eq!(result_count1.load(Ordering::SeqCst), 1);
        assert_eq!(result_count2.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_unregister_join() {
        let mut manager = StreamJoinManager::new();
        let result_count = Arc::new(AtomicUsize::new(0));

        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::Inner, 10),
            counting_handler(&result_count),
        );
        manager.unregister_join("join1");

        manager.process_event(create_test_event("left", 1000, "user1"));
        manager.process_event(create_test_event("right", 1005, "user1"));

        assert_eq!(result_count.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_watermark_update() {
        let mut manager = StreamJoinManager::new();
        let result_count = Arc::new(AtomicUsize::new(0));

        manager.register_join(
            "join1".to_string(),
            user_id_join("left", "right", JoinType::LeftOuter, 5),
            counting_handler(&result_count),
        );

        manager.process_event(create_test_event("left", 1000, "user1"));
        assert_eq!(result_count.load(Ordering::SeqCst), 1);

        manager.update_watermark("left", 10000);
        assert_eq!(result_count.load(Ordering::SeqCst), 1);
    }
}