saorsa_core/events/
mod.rs1use crate::fwid::Key;
20use crate::types::Forward;
21use anyhow::Result;
22use bytes::Bytes;
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::sync::Arc;
26use tokio::sync::{RwLock, broadcast};
27
28pub struct Subscription<T> {
30 receiver: broadcast::Receiver<T>,
31}
32
33impl<T: Clone> Subscription<T> {
34 fn new(receiver: broadcast::Receiver<T>) -> Self {
36 Self { receiver }
37 }
38
39 pub async fn recv(&mut self) -> Result<T> {
41 self.receiver
42 .recv()
43 .await
44 .map_err(|e| anyhow::anyhow!("Subscription error: {}", e))
45 }
46
47 pub fn try_recv(&mut self) -> Result<T> {
49 self.receiver
50 .try_recv()
51 .map_err(|e| anyhow::anyhow!("Subscription error: {}", e))
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum TopologyEvent {
58 PeerJoined { peer_id: Vec<u8>, address: String },
60 PeerLeft { peer_id: Vec<u8>, reason: String },
62 PartitionDetected {
64 partition_id: u64,
65 affected_peers: Vec<Vec<u8>>,
66 },
67 PartitionHealed { partition_id: u64 },
69 RoutingTableUpdated {
71 added: Vec<Vec<u8>>,
72 removed: Vec<Vec<u8>>,
73 },
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub enum DhtWatchEvent {
79 ValueStored { key: Key, value: Vec<u8> },
81 ValueUpdated {
83 key: Key,
84 old_value: Vec<u8>,
85 new_value: Vec<u8>,
86 },
87 ValueDeleted { key: Key },
89 KeyExpired { key: Key },
91}
92
93pub struct EventBus {
97 topology_tx: broadcast::Sender<TopologyEvent>,
99
100 dht_watches: Arc<RwLock<HashMap<Key, broadcast::Sender<Bytes>>>>,
102
103 forward_watches: Arc<RwLock<HashMap<Key, broadcast::Sender<Forward>>>>,
105}
106
107impl EventBus {
108 pub fn new() -> Self {
110 let (topology_tx, _) = broadcast::channel(1000);
111
112 Self {
113 topology_tx,
114 dht_watches: Arc::new(RwLock::new(HashMap::new())),
115 forward_watches: Arc::new(RwLock::new(HashMap::new())),
116 }
117 }
118
119 pub fn subscribe_topology(&self) -> Subscription<TopologyEvent> {
121 Subscription::new(self.topology_tx.subscribe())
122 }
123
124 pub async fn publish_topology(&self, event: TopologyEvent) -> Result<()> {
126 self.topology_tx
127 .send(event)
128 .map_err(|_| anyhow::anyhow!("No topology subscribers"))?;
129 Ok(())
130 }
131
132 pub async fn subscribe_dht_key(&self, key: Key) -> Subscription<Bytes> {
134 let mut watches = self.dht_watches.write().await;
135
136 let tx = watches.entry(key).or_insert_with(|| {
137 let (tx, _) = broadcast::channel(100);
138 tx
139 });
140
141 Subscription::new(tx.subscribe())
142 }
143
144 pub async fn publish_dht_update(&self, key: Key, value: Bytes) -> Result<()> {
146 let watches = self.dht_watches.read().await;
147
148 if let Some(tx) = watches.get(&key) {
149 let _ = tx.send(value); }
151
152 Ok(())
153 }
154
155 pub async fn subscribe_forwards(&self, identity_key: Key) -> Subscription<Forward> {
157 let mut watches = self.forward_watches.write().await;
158
159 let tx = watches.entry(identity_key).or_insert_with(|| {
160 let (tx, _) = broadcast::channel(100);
161 tx
162 });
163
164 Subscription::new(tx.subscribe())
165 }
166
167 pub async fn publish_forward_for(&self, identity_key: Key, forward: Forward) -> Result<()> {
169 let watches = self.forward_watches.read().await;
170
171 if let Some(tx) = watches.get(&identity_key) {
172 let _ = tx.send(forward); }
174
175 Ok(())
176 }
177
178 pub async fn cleanup_expired(&self) {
180 let mut dht_watches = self.dht_watches.write().await;
181 dht_watches.retain(|_, tx| tx.receiver_count() > 0);
182
183 let mut forward_watches = self.forward_watches.write().await;
184 forward_watches.retain(|_, tx| tx.receiver_count() > 0);
185 }
186}
187
188impl Default for EventBus {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194static GLOBAL_BUS: once_cell::sync::Lazy<EventBus> = once_cell::sync::Lazy::new(EventBus::new);
196
197pub fn global_bus() -> &'static EventBus {
199 &GLOBAL_BUS
200}
201
202pub fn subscribe_topology() -> Subscription<TopologyEvent> {
204 global_bus().subscribe_topology()
205}
206
207pub async fn dht_watch(key: Key) -> Subscription<Bytes> {
209 global_bus().subscribe_dht_key(key).await
210}
211
212pub async fn device_subscribe(identity_key: Key) -> Subscription<Forward> {
214 global_bus().subscribe_forwards(identity_key).await
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// A published topology event reaches an existing subscriber unchanged.
    #[tokio::test]
    async fn test_topology_events() {
        let bus = EventBus::new();
        let mut subscription = bus.subscribe_topology();

        bus.publish_topology(TopologyEvent::PeerJoined {
            peer_id: vec![1, 2, 3],
            address: "127.0.0.1:9000".to_string(),
        })
        .await
        .unwrap();

        let received = subscription.recv().await.unwrap();
        let TopologyEvent::PeerJoined { peer_id, address } = received else {
            panic!("Wrong event type");
        };
        assert_eq!(peer_id, vec![1, 2, 3]);
        assert_eq!(address, "127.0.0.1:9000");
    }

    /// A DHT update is delivered to the watcher of that key.
    #[tokio::test]
    async fn test_dht_watch() {
        let bus = EventBus::new();
        let key = Key::new([42u8; 32]);
        let mut subscription = bus.subscribe_dht_key(key.clone()).await;

        let payload = Bytes::from_static(&[1, 2, 3, 4]);
        bus.publish_dht_update(key, payload.clone()).await.unwrap();

        assert_eq!(subscription.recv().await.unwrap(), payload);
    }

    /// A forward update is delivered to the subscriber of that identity key.
    #[tokio::test]
    async fn test_forward_events() {
        let bus = EventBus::new();
        let identity_key = Key::new([99u8; 32]);
        let mut subscription = bus.subscribe_forwards(identity_key.clone()).await;

        let forward = Forward {
            proto: "ant-quic".to_string(),
            addr: "quic://example.com:9000".to_string(),
            exp: 1234567890,
        };
        bus.publish_forward_for(identity_key, forward)
            .await
            .unwrap();

        let received = subscription.recv().await.unwrap();
        assert_eq!(received.proto, "ant-quic");
        assert_eq!(received.addr, "quic://example.com:9000");
    }

    /// Channels left without receivers are removed by `cleanup_expired`.
    #[tokio::test]
    async fn test_cleanup() {
        let bus = EventBus::new();
        let key = Key::new([1u8; 32]);

        // Subscribe, then drop the subscription right away so the channel has no receiver.
        drop(bus.subscribe_dht_key(key).await);
        assert_eq!(bus.dht_watches.read().await.len(), 1);

        bus.cleanup_expired().await;
        assert_eq!(bus.dht_watches.read().await.len(), 0);
    }
}