Skip to main content

saorsa_core/events/
mod.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: david@saorsalabs.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! Async event bus for watches and topology changes.
15//!
16//! This module provides subscription-based event handling for
17//! state changes throughout the system.
18
19use crate::fwid::Key;
20use crate::types::Forward;
21use anyhow::Result;
22use bytes::Bytes;
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::sync::Arc;
26use tokio::sync::{RwLock, broadcast};
27
28/// A subscription handle for receiving events
29pub struct Subscription<T> {
30    receiver: broadcast::Receiver<T>,
31}
32
33impl<T: Clone> Subscription<T> {
34    /// Create a new subscription with a receiver
35    fn new(receiver: broadcast::Receiver<T>) -> Self {
36        Self { receiver }
37    }
38
39    /// Receive the next event
40    pub async fn recv(&mut self) -> Result<T> {
41        self.receiver
42            .recv()
43            .await
44            .map_err(|e| anyhow::anyhow!("Subscription error: {}", e))
45    }
46
47    /// Try to receive without blocking
48    pub fn try_recv(&mut self) -> Result<T> {
49        self.receiver
50            .try_recv()
51            .map_err(|e| anyhow::anyhow!("Subscription error: {}", e))
52    }
53}
54
55/// Network topology change events
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum TopologyEvent {
58    /// A new peer joined the network
59    PeerJoined { peer_id: Vec<u8>, address: String },
60    /// A peer left the network
61    PeerLeft { peer_id: Vec<u8>, reason: String },
62    /// Network partition detected
63    PartitionDetected {
64        partition_id: u64,
65        affected_peers: Vec<Vec<u8>>,
66    },
67    /// Network partition healed
68    PartitionHealed { partition_id: u64 },
69    /// Routing table updated
70    RoutingTableUpdated {
71        added: Vec<Vec<u8>>,
72        removed: Vec<Vec<u8>>,
73    },
74}
75
76/// DHT key watch events
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub enum DhtWatchEvent {
79    /// Value stored at key
80    ValueStored { key: Key, value: Vec<u8> },
81    /// Value updated at key
82    ValueUpdated {
83        key: Key,
84        old_value: Vec<u8>,
85        new_value: Vec<u8>,
86    },
87    /// Value deleted at key
88    ValueDeleted { key: Key },
89    /// Key expired
90    KeyExpired { key: Key },
91}
92
93// ForwardEvent removed; we publish Forward values scoped by identity key
94
95/// The main event bus for the system
96pub struct EventBus {
97    /// Topology event broadcaster
98    topology_tx: broadcast::Sender<TopologyEvent>,
99
100    /// DHT watch broadcasters by key
101    dht_watches: Arc<RwLock<HashMap<Key, broadcast::Sender<Bytes>>>>,
102
103    /// Forward event broadcasters by identity key
104    forward_watches: Arc<RwLock<HashMap<Key, broadcast::Sender<Forward>>>>,
105}
106
107impl EventBus {
108    /// Create a new event bus
109    pub fn new() -> Self {
110        let (topology_tx, _) = broadcast::channel(1000);
111
112        Self {
113            topology_tx,
114            dht_watches: Arc::new(RwLock::new(HashMap::new())),
115            forward_watches: Arc::new(RwLock::new(HashMap::new())),
116        }
117    }
118
119    /// Subscribe to topology events
120    pub fn subscribe_topology(&self) -> Subscription<TopologyEvent> {
121        Subscription::new(self.topology_tx.subscribe())
122    }
123
124    /// Publish a topology event
125    pub async fn publish_topology(&self, event: TopologyEvent) -> Result<()> {
126        self.topology_tx
127            .send(event)
128            .map_err(|_| anyhow::anyhow!("No topology subscribers"))?;
129        Ok(())
130    }
131
132    /// Subscribe to DHT key watches
133    pub async fn subscribe_dht_key(&self, key: Key) -> Subscription<Bytes> {
134        let mut watches = self.dht_watches.write().await;
135
136        let tx = watches.entry(key).or_insert_with(|| {
137            let (tx, _) = broadcast::channel(100);
138            tx
139        });
140
141        Subscription::new(tx.subscribe())
142    }
143
144    /// Publish a DHT key update
145    pub async fn publish_dht_update(&self, key: Key, value: Bytes) -> Result<()> {
146        let watches = self.dht_watches.read().await;
147
148        if let Some(tx) = watches.get(&key) {
149            let _ = tx.send(value); // Ignore if no subscribers
150        }
151
152        Ok(())
153    }
154
155    /// Subscribe to forward announcements for an identity
156    pub async fn subscribe_forwards(&self, identity_key: Key) -> Subscription<Forward> {
157        let mut watches = self.forward_watches.write().await;
158
159        let tx = watches.entry(identity_key).or_insert_with(|| {
160            let (tx, _) = broadcast::channel(100);
161            tx
162        });
163
164        Subscription::new(tx.subscribe())
165    }
166
167    /// Publish a forward announcement scoped to identity
168    pub async fn publish_forward_for(&self, identity_key: Key, forward: Forward) -> Result<()> {
169        let watches = self.forward_watches.read().await;
170
171        if let Some(tx) = watches.get(&identity_key) {
172            let _ = tx.send(forward); // Ignore if no subscribers
173        }
174
175        Ok(())
176    }
177
178    /// Clean up expired subscriptions
179    pub async fn cleanup_expired(&self) {
180        let mut dht_watches = self.dht_watches.write().await;
181        dht_watches.retain(|_, tx| tx.receiver_count() > 0);
182
183        let mut forward_watches = self.forward_watches.write().await;
184        forward_watches.retain(|_, tx| tx.receiver_count() > 0);
185    }
186}
187
188impl Default for EventBus {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194/// Global event bus instance (for convenience)
195static GLOBAL_BUS: once_cell::sync::Lazy<EventBus> = once_cell::sync::Lazy::new(EventBus::new);
196
197/// Get the global event bus
198pub fn global_bus() -> &'static EventBus {
199    &GLOBAL_BUS
200}
201
202/// Helper function to subscribe to topology events
203pub fn subscribe_topology() -> Subscription<TopologyEvent> {
204    global_bus().subscribe_topology()
205}
206
207/// Helper function to subscribe to DHT key
208pub async fn dht_watch(key: Key) -> Subscription<Bytes> {
209    global_bus().subscribe_dht_key(key).await
210}
211
212/// Helper function to subscribe to device forwards
213pub async fn device_subscribe(identity_key: Key) -> Subscription<Forward> {
214    global_bus().subscribe_forwards(identity_key).await
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_topology_events() {
        let bus = EventBus::new();
        let mut sub = bus.subscribe_topology();

        let event = TopologyEvent::PeerJoined {
            peer_id: vec![1, 2, 3],
            address: "127.0.0.1:9000".to_string(),
        };

        bus.publish_topology(event.clone()).await.unwrap();

        let received = sub.recv().await.unwrap();
        match received {
            TopologyEvent::PeerJoined { peer_id, address } => {
                assert_eq!(peer_id, vec![1, 2, 3]);
                assert_eq!(address, "127.0.0.1:9000");
            }
            _ => panic!("Wrong event type"),
        }
    }

    /// `publish_topology` reports an error when nobody is subscribed, unlike the
    /// per-key publishers.
    #[tokio::test]
    async fn test_publish_topology_without_subscribers_errors() {
        let bus = EventBus::new();
        let event = TopologyEvent::PartitionHealed { partition_id: 7 };
        assert!(bus.publish_topology(event).await.is_err());
    }

    #[tokio::test]
    async fn test_dht_watch() {
        let bus = EventBus::new();
        let key = Key::new([42u8; 32]);

        let mut sub = bus.subscribe_dht_key(key.clone()).await;

        let value = Bytes::from_static(&[1, 2, 3, 4]);
        bus.publish_dht_update(key, value.clone()).await.unwrap();

        let received = sub.recv().await.unwrap();
        assert_eq!(received, value);
    }

    #[tokio::test]
    async fn test_forward_events() {
        let bus = EventBus::new();
        let identity_key = Key::new([99u8; 32]);

        let mut sub = bus.subscribe_forwards(identity_key.clone()).await;

        let fwd = Forward {
            proto: "ant-quic".to_string(),
            addr: "quic://example.com:9000".to_string(),
            exp: 1234567890,
        };

        bus.publish_forward_for(identity_key.clone(), fwd.clone())
            .await
            .unwrap();

        let received = sub.recv().await.unwrap();
        assert_eq!(received.proto, "ant-quic");
        assert_eq!(received.addr, "quic://example.com:9000");
    }

    /// Publishing to a key nobody watches succeeds and must not create a channel.
    #[tokio::test]
    async fn test_publish_to_unwatched_key_is_noop() {
        let bus = EventBus::new();
        let key = Key::new([7u8; 32]);

        bus.publish_dht_update(key.clone(), Bytes::from_static(&[9]))
            .await
            .unwrap();
        assert_eq!(bus.dht_watches.read().await.len(), 0);

        let fwd = Forward {
            proto: "ant-quic".to_string(),
            addr: "quic://example.com:9000".to_string(),
            exp: 1234567890,
        };
        bus.publish_forward_for(key, fwd).await.unwrap();
        assert_eq!(bus.forward_watches.read().await.len(), 0);
    }

    #[tokio::test]
    async fn test_cleanup() {
        let bus = EventBus::new();
        let key = Key::new([1u8; 32]);

        // Create subscription then drop it
        {
            let _sub = bus.subscribe_dht_key(key.clone()).await;
        }

        // Check that watch exists
        assert_eq!(bus.dht_watches.read().await.len(), 1);

        // Clean up
        bus.cleanup_expired().await;

        // Watch should be removed
        assert_eq!(bus.dht_watches.read().await.len(), 0);
    }

    /// Cleanup also prunes forward watches, and only those without live subscribers.
    #[tokio::test]
    async fn test_cleanup_keeps_live_forward_watches() {
        let bus = EventBus::new();
        let live = Key::new([2u8; 32]);
        let dead = Key::new([3u8; 32]);

        let _live_sub = bus.subscribe_forwards(live.clone()).await;
        {
            let _dead_sub = bus.subscribe_forwards(dead.clone()).await;
        }
        assert_eq!(bus.forward_watches.read().await.len(), 2);

        bus.cleanup_expired().await;

        let watches = bus.forward_watches.read().await;
        assert_eq!(watches.len(), 1);
        assert!(watches.contains_key(&live));
        assert!(!watches.contains_key(&dead));
    }
}