Skip to main content

veilid_tools/
event_bus.rs

1//! Event Bus
2
3use super::*;
4use futures_util::stream::{FuturesUnordered, StreamExt};
5use stop_token::future::FutureExt as _;
6
7use std::any::{Any, TypeId};
8
9type AnyEventHandler =
10    Arc<dyn Fn(Arc<dyn Any + Send + Sync>) -> PinBoxFutureStatic<()> + Send + Sync>;
11type SubscriptionId = u64;
12
13#[derive(Debug)]
14pub struct EventBusSubscription {
15    id: SubscriptionId,
16    type_id: TypeId,
17}
18
/// An event that has been posted and is waiting in the bus queue.
struct QueuedEvent {
    /// The posted event, type-erased so events of every type share one queue.
    evt: Arc<dyn Any + Send + Sync>,
    /// `std::any::type_name` of the posted event, used in diagnostics.
    type_name: &'static str,
}
23
24struct EventBusUnlockedInner {
25    tx: flume::Sender<QueuedEvent>,
26    rx: flume::Receiver<QueuedEvent>,
27    startup_lock: StartupLock,
28}
29
30impl fmt::Debug for EventBusUnlockedInner {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("EventBusUnlockedInner")
33            .field("tx", &self.tx)
34            .field("rx", &self.rx)
35            .field("startup_lock", &self.startup_lock)
36            .finish()
37    }
38}
39
40struct EventBusInner {
41    handlers: HashMap<TypeId, Vec<(SubscriptionId, AnyEventHandler)>>,
42    next_sub_id: SubscriptionId,
43    free_sub_ids: Vec<SubscriptionId>,
44    stop_source: Option<StopSource>,
45    bus_processor_jh: Option<MustJoinHandle<()>>,
46}
47
48impl fmt::Debug for EventBusInner {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("EventBusInner")
51            .field("handlers.len", &self.handlers.len())
52            .field("next_sub_id", &self.next_sub_id)
53            .field("free_sub_ids", &self.free_sub_ids)
54            .finish()
55    }
56}
57
58/// Event bus
59///
60/// Asynchronously handles events of arbitrary Any type
61/// by passing them in-order to a set of registered async 'handler' functions.
62/// Handlers are processes in an unordered fashion, but an event is fully handled by all handlers
63/// until the next event in the posted event stream is processed.
64#[derive(Debug, Clone)]
65pub struct EventBus {
66    unlocked_inner: Arc<EventBusUnlockedInner>,
67    inner: Arc<Mutex<EventBusInner>>,
68}
69
70impl Default for EventBus {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl EventBus {
77    ////////////////////////////////////////////////////////////////////
78    // Public interface
79
80    /// Create a new EventBus
81    #[must_use]
82    pub fn new() -> Self {
83        let (tx, rx) = flume::unbounded();
84        Self {
85            unlocked_inner: Arc::new(EventBusUnlockedInner {
86                tx,
87                rx,
88                startup_lock: StartupLock::new(),
89            }),
90            inner: Arc::new(Mutex::new(Self::new_inner())),
91        }
92    }
93
94    /// Start up the EventBus background processor
95    pub fn startup(&self) -> Result<(), StartupLockAlreadyStartedError> {
96        let guard = self.unlocked_inner.startup_lock.startup()?;
97        {
98            let mut inner = self.inner.lock();
99            let stop_source = StopSource::new();
100            let stop_token = stop_source.token();
101            inner.stop_source = Some(stop_source);
102
103            let bus_processor_jh = spawn(
104                "event bus processor",
105                self.clone().bus_processor(stop_token),
106            );
107            inner.bus_processor_jh = Some(bus_processor_jh);
108        }
109
110        guard.success();
111        Ok(())
112    }
113
114    /// Shut down EventBus background processing
115    /// This unregisters all handlers as well and discards any unprocessed events
116    pub async fn shutdown(&self) {
117        let Ok(guard) = self.unlocked_inner.startup_lock.shutdown().await else {
118            return;
119        };
120
121        let opt_jh = {
122            let mut inner = self.inner.lock();
123            drop(inner.stop_source.take());
124            inner.bus_processor_jh.take()
125        };
126
127        if let Some(jh) = opt_jh {
128            jh.await;
129        }
130
131        *self.inner.lock() = Self::new_inner();
132
133        guard.success();
134    }
135
136    /// Post an event to be processed
137    pub fn post<E: Any + Send + Sync + 'static>(
138        &self,
139        evt: E,
140    ) -> Result<(), StartupLockNotStartedError> {
141        let _guard = self.unlocked_inner.startup_lock.enter()?;
142
143        if let Err(e) = self.unlocked_inner.tx.send(QueuedEvent {
144            evt: Arc::new(evt),
145            type_name: std::any::type_name::<E>(),
146        }) {
147            error!("{}", e);
148        }
149        Ok(())
150    }
151
152    /// Subscribe a handler to handle all posted events of a particular type
153    /// Returns an subscription object that can be used to cancel this specific subscription if desired
154    pub fn subscribe<
155        E: Any + Send + Sync + 'static,
156        F: Fn(Arc<E>) -> PinBoxFutureStatic<()> + Send + Sync + 'static,
157    >(
158        &self,
159        handler: F,
160    ) -> EventBusSubscription {
161        let handler = Arc::new(handler);
162        let type_id = TypeId::of::<E>();
163        let mut inner = self.inner.lock();
164
165        let id = inner.free_sub_ids.pop().unwrap_or_else(|| {
166            let id = inner.next_sub_id;
167            inner.next_sub_id += 1;
168            id
169        });
170
171        inner.handlers.entry(type_id).or_default().push((
172            id,
173            Arc::new(move |any_evt| {
174                let handler = handler.clone();
175                Box::pin(async move {
176                    handler(any_evt.downcast::<E>().unwrap_or_log()).await;
177                })
178            }),
179        ));
180
181        EventBusSubscription { id, type_id }
182    }
183
184    /// Given a subscription object returned from `subscribe`, removes the
185    /// subscription for the EventBus. The handler will no longer be called.
186    pub fn unsubscribe(&self, sub: EventBusSubscription) {
187        let mut inner = self.inner.lock();
188
189        inner.handlers.entry(sub.type_id).and_modify(|e| {
190            let index = e.iter().position(|x| x.0 == sub.id).unwrap_or_log();
191            e.remove(index);
192        });
193
194        inner.free_sub_ids.push(sub.id);
195    }
196
197    /// Returns the number of unprocessed events remaining
198    #[must_use]
199    pub fn len(&self) -> usize {
200        self.unlocked_inner.rx.len()
201    }
202
203    /// Checks if the bus has no events
204    #[must_use]
205    pub fn is_empty(&self) -> bool {
206        self.unlocked_inner.rx.is_empty()
207    }
208
209    ////////////////////////////////////////////////////////////////////
210    // Internal implementation
211
212    fn new_inner() -> EventBusInner {
213        EventBusInner {
214            handlers: HashMap::new(),
215            next_sub_id: 0,
216            free_sub_ids: vec![],
217            stop_source: None,
218            bus_processor_jh: None,
219        }
220    }
221
222    async fn bus_processor(self, stop_token: StopToken) {
223        while let Ok(Ok(qe)) = self
224            .unlocked_inner
225            .rx
226            .recv_async()
227            .timeout_at(stop_token.clone())
228            .await
229        {
230            let Ok(_guard) = self.unlocked_inner.startup_lock.enter() else {
231                break;
232            };
233            let type_id = (qe.evt.as_ref()).type_id();
234            let type_name = qe.type_name;
235
236            let opt_handlers: Option<FuturesUnordered<_>> = {
237                let mut inner = self.inner.lock();
238                match inner.handlers.entry(type_id) {
239                    std::collections::hash_map::Entry::Occupied(entry) => Some(
240                        entry
241                            .get()
242                            .iter()
243                            .map(|(_id, handler)| handler(qe.evt.clone()))
244                            .collect(),
245                    ),
246                    std::collections::hash_map::Entry::Vacant(_) => {
247                        error!("no handlers for event: {}", type_name);
248                        None
249                    }
250                }
251            };
252
253            // Process all handlers for this event simultaneously
254            if let Some(mut handlers) = opt_handlers {
255                while let Ok(Some(_)) = handlers.next().timeout_at(stop_token.clone()).await {}
256            }
257        }
258    }
259}