// sh_layer1/event_bus.rs

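//! Event bus module.
//!
//! Publish/subscribe, event sourcing, persistence.
//!
//! This file implements the in-process publish/subscribe part: handlers
//! register per event type and `EventBus::publish` invokes them on the
//! calling thread. `Event` derives serde `Serialize`/`Deserialize`, so events
//! can be stored or replayed elsewhere, but no event log or persistence is
//! implemented in this file.
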
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Process-wide source of handler IDs.
///
/// Every `fetch_add` returns a value no earlier call returned (short of
/// wrapping past `usize::MAX`), so IDs are distinct across all `EventBus`
/// instances, not just within one bus.
static HANDLER_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Opaque token identifying one subscription.
///
/// Returned by [`EventBus::subscribe`]; hand it back to
/// [`EventBus::unsubscribe`] to remove that handler. The inner counter value is
/// private, so callers outside this module can only obtain IDs from `subscribe`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct HandlerId(usize);

17/// 事件
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Event {
20    pub id: String,
21    pub event_type: String,
22    pub payload: String,
23    pub timestamp: String,
24}
25
26/// 事件处理器(带 ID)
27struct HandlerEntry {
28    id: HandlerId,
29    handler: Box<dyn Fn(&Event) + Send + Sync>,
30}
31
32/// 事件总线
33pub struct EventBus {
34    handlers: RwLock<HashMap<String, Vec<HandlerEntry>>>,
35}
36
37impl EventBus {
38    pub fn new() -> Self {
39        Self {
40            handlers: RwLock::new(HashMap::new()),
41        }
42    }
43
44    /// 订阅事件,返回 handler ID 用于取消订阅
45    pub fn subscribe<F>(&self, event_type: &str, handler: F) -> HandlerId
46    where
47        F: Fn(&Event) + Send + Sync + 'static,
48    {
49        let id = HandlerId(HANDLER_ID_COUNTER.fetch_add(1, Ordering::Relaxed));
50        let entry = HandlerEntry {
51            id,
52            handler: Box::new(handler),
53        };
54        self.handlers
55            .write()
56            .entry(event_type.to_string())
57            .or_default()
58            .push(entry);
59        id
60    }
61
62    /// 取消订阅
63    pub fn unsubscribe(&self, event_type: &str, handler_id: HandlerId) -> bool {
64        let mut handlers = self.handlers.write();
65
66        // 获取条目,执行 retain,并记录结果
67        let (removed, is_empty) = if let Some(entries) = handlers.get_mut(event_type) {
68            let original_len = entries.len();
69            entries.retain(|e| e.id != handler_id);
70            let new_len = entries.len();
71            (original_len > new_len, new_len == 0)
72        } else {
73            return false;
74        };
75
76        // 现在可以安全地移除 key(entries 的借用已结束)
77        if is_empty {
78            handlers.remove(event_type);
79        }
80
81        removed
82    }
83
84    /// 取消某事件类型的所有订阅
85    pub fn unsubscribe_all(&self, event_type: &str) -> usize {
86        let mut handlers = self.handlers.write();
87        handlers.remove(event_type).map(|v| v.len()).unwrap_or(0)
88    }
89
90    /// 发布事件
91    pub fn publish(&self, event: &Event) {
92        if let Some(handlers) = self.handlers.read().get(&event.event_type) {
93            for entry in handlers {
94                (entry.handler)(event);
95            }
96        }
97    }
98
99    /// 获取某事件类型的订阅数量
100    pub fn subscriber_count(&self, event_type: &str) -> usize {
101        self.handlers
102            .read()
103            .get(event_type)
104            .map(|v| v.len())
105            .unwrap_or(0)
106    }
107
108    /// 获取所有事件类型的订阅总数
109    pub fn total_subscribers(&self) -> usize {
110        self.handlers.read().values().map(|v| v.len()).sum()
111    }
112}
113
114impl Default for EventBus {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use std::thread;

    /// Builds an event with an empty timestamp. Tests that care about the
    /// timestamp override it with struct-update syntax.
    fn make_event(id: &str, event_type: &str, payload: &str) -> Event {
        Event {
            id: id.to_string(),
            event_type: event_type.to_string(),
            payload: payload.to_string(),
            timestamp: String::new(),
        }
    }

    #[test]
    fn test_subscribe_and_publish() {
        let bus = EventBus::new();
        let received = Arc::new(Mutex::new(String::new()));
        let received_clone = Arc::clone(&received);

        bus.subscribe("test_event", move |event| {
            *received_clone.lock().unwrap() = event.payload.clone();
        });

        let event = Event {
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            ..make_event("1", "test_event", "hello world")
        };
        bus.publish(&event);

        assert_eq!(*received.lock().unwrap(), "hello world");
    }

    #[test]
    fn test_multiple_subscribers() {
        let bus = EventBus::new();
        let counter = Arc::new(Mutex::new(0));
        let counter1 = Arc::clone(&counter);
        let counter2 = Arc::clone(&counter);

        bus.subscribe("increment", move |_| {
            *counter1.lock().unwrap() += 1;
        });
        bus.subscribe("increment", move |_| {
            *counter2.lock().unwrap() += 10;
        });

        bus.publish(&make_event("1", "increment", ""));

        assert_eq!(*counter.lock().unwrap(), 11);
    }

    #[test]
    fn test_no_subscribers() {
        let bus = EventBus::new();

        // Publishing to a type nobody listens to is a no-op and must not panic.
        bus.publish(&make_event("1", "unknown_event", ""));

        assert_eq!(bus.subscriber_count("unknown_event"), 0);
        assert_eq!(bus.total_subscribers(), 0);
    }

    #[test]
    fn test_different_event_types() {
        let bus = EventBus::new();
        let results = Arc::new(Mutex::new(Vec::new()));
        let r1 = Arc::clone(&results);
        let r2 = Arc::clone(&results);

        bus.subscribe("event_a", move |_| {
            r1.lock().unwrap().push("A");
        });
        bus.subscribe("event_b", move |_| {
            r2.lock().unwrap().push("B");
        });

        bus.publish(&make_event("1", "event_a", ""));
        bus.publish(&make_event("2", "event_b", ""));

        let res = results.lock().unwrap();
        assert_eq!(*res, vec!["A", "B"]);
    }

    #[test]
    fn test_event_serialization() {
        let event = Event {
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            ..make_event("123", "test", "data")
        };

        let json = serde_json::to_string(&event).unwrap();
        // Check key/value pairs rather than bare substrings, which could match
        // anywhere in the document.
        assert!(json.contains(r#""id":"123""#));
        assert!(json.contains(r#""event_type":"test""#));
        assert!(json.contains(r#""payload":"data""#));

        // A round-trip must reproduce every field.
        let back: Event = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, event.id);
        assert_eq!(back.event_type, event.event_type);
        assert_eq!(back.payload, event.payload);
        assert_eq!(back.timestamp, event.timestamp);
    }

    #[test]
    fn test_event_deserialization() {
        let json = r#"{
            "id": "456",
            "event_type": "my_event",
            "payload": "my_payload",
            "timestamp": "2024-01-01T00:00:00Z"
        }"#;

        let event: Event = serde_json::from_str(json).unwrap();
        assert_eq!(event.id, "456");
        assert_eq!(event.event_type, "my_event");
        assert_eq!(event.payload, "my_payload");
        assert_eq!(event.timestamp, "2024-01-01T00:00:00Z");
    }

    #[test]
    fn test_default_event_bus() {
        let bus = EventBus::default();
        assert_eq!(bus.total_subscribers(), 0);

        // Must not panic.
        bus.publish(&make_event("1", "test", ""));
    }

    #[test]
    fn test_concurrent_publish() {
        let bus = Arc::new(EventBus::new());
        let counter = Arc::new(AtomicUsize::new(0));

        let c1 = Arc::clone(&counter);
        bus.subscribe("count", move |_| {
            c1.fetch_add(1, Ordering::SeqCst);
        });

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let b = Arc::clone(&bus);
                thread::spawn(move || b.publish(&make_event("1", "count", "")))
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_event_with_empty_payload() {
        let bus = EventBus::new();
        let received = Arc::new(Mutex::new(None));
        let r = Arc::clone(&received);

        bus.subscribe("empty", move |event| {
            *r.lock().unwrap() = Some(event.payload.clone());
        });

        bus.publish(&make_event("1", "empty", ""));

        // The handler ran and saw the empty payload unchanged.
        assert_eq!(*received.lock().unwrap(), Some(String::new()));
    }

    #[test]
    fn test_unsubscribe() {
        let bus = EventBus::new();
        let counter = Arc::new(Mutex::new(0));
        let c1 = Arc::clone(&counter);

        let handler_id = bus.subscribe("test", move |_| {
            *c1.lock().unwrap() += 1;
        });
        assert_eq!(bus.subscriber_count("test"), 1);

        let event = make_event("1", "test", "");
        bus.publish(&event);
        assert_eq!(*counter.lock().unwrap(), 1);

        // Unsubscribe.
        assert!(bus.unsubscribe("test", handler_id));
        assert_eq!(bus.subscriber_count("test"), 0);

        // Publishing again must not reach the removed handler.
        bus.publish(&event);
        assert_eq!(*counter.lock().unwrap(), 1);

        // Unsubscribing the same ID a second time reports false.
        assert!(!bus.unsubscribe("test", handler_id));
    }

    #[test]
    fn test_unsubscribe_wrong_event_type() {
        let bus = EventBus::new();
        let id = bus.subscribe("a", |_| {});
        bus.subscribe("b", |_| {});

        // An ID is only found under the event type it was subscribed to.
        assert!(!bus.unsubscribe("b", id));
        assert_eq!(bus.subscriber_count("a"), 1);
        assert_eq!(bus.subscriber_count("b"), 1);
    }

    #[test]
    fn test_unsubscribe_all() {
        let bus = EventBus::new();

        bus.subscribe("a", |_| {});
        bus.subscribe("a", |_| {});
        bus.subscribe("b", |_| {});

        assert_eq!(bus.total_subscribers(), 3);

        let removed = bus.unsubscribe_all("a");
        assert_eq!(removed, 2);
        assert_eq!(bus.total_subscribers(), 1);
        assert_eq!(bus.subscriber_count("a"), 0);
        assert_eq!(bus.subscriber_count("b"), 1);

        // Nothing left to remove the second time.
        assert_eq!(bus.unsubscribe_all("a"), 0);
    }

    #[test]
    fn test_handler_id_unique() {
        let bus = EventBus::new();
        let id1 = bus.subscribe("test", |_| {});
        let id2 = bus.subscribe("test", |_| {});
        assert_ne!(id1, id2);
    }
}