tako/
signals.rs

1//! In-process signal arbiter and dispatch system.
2//!
3//! This module defines a small abstraction for named signals that can be emitted
4//! and handled within a Tako application. It is intended for cross-cutting
5//! concerns such as metrics, logging hooks, or custom application events.
6
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::{any::Any, collections::HashMap, sync::Arc};
9
10use dashmap::DashMap;
11use futures_util::future::{BoxFuture, join_all};
12use once_cell::sync::Lazy;
13use tokio::sync::{broadcast, mpsc};
14use tokio::time::{Duration, timeout};
15
/// Capacity given to topic channels until
/// `SignalArbiter::set_global_broadcast_capacity` stores a different value.
const DEFAULT_BROADCAST_CAPACITY: usize = 64;

/// Process-wide capacity read whenever a new broadcast topic is created.
static GLOBAL_BROADCAST_CAPACITY: AtomicUsize = AtomicUsize::new(DEFAULT_BROADCAST_CAPACITY);
18
/// Well-known signal identifiers for common lifecycle and request events.
///
/// Ids are dot-separated strings, so related ones can be observed together via
/// `SignalArbiter::subscribe_prefix` (for example with the `"request."` prefix).
pub mod ids {
  // Server lifecycle.
  pub const SERVER_STARTED: &str = "server.started";
  pub const SERVER_STOPPED: &str = "server.stopped";

  // Connections.
  pub const CONNECTION_OPENED: &str = "connection.opened";
  pub const CONNECTION_CLOSED: &str = "connection.closed";

  // Requests (server-wide).
  pub const REQUEST_STARTED: &str = "request.started";
  pub const REQUEST_COMPLETED: &str = "request.completed";

  // Router and RPC.
  pub const ROUTER_HOT_RELOAD: &str = "router.hot_reload";
  pub const RPC_ERROR: &str = "rpc.error";

  // Requests (per route).
  pub const ROUTE_REQUEST_STARTED: &str = "route.request.started";
  pub const ROUTE_REQUEST_COMPLETED: &str = "route.request.completed";
}
32
/// A signal emitted through the arbiter.
///
/// Signals are identified by an arbitrary string and can carry a map of
/// string metadata. Callers are free to define their own conventions for ids
/// and fields. Two signals are equal when both their ids and their complete
/// metadata maps are equal.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Signal {
  /// Identifier of the signal, for example "request.started" or "metrics.tick".
  pub id: String,
  /// Optional metadata payload carried with the signal.
  pub metadata: HashMap<String, String>,
}
45
46impl Signal {
47  /// Creates a new signal with the given id and empty metadata.
48  pub fn new(id: impl Into<String>) -> Self {
49    Self {
50      id: id.into(),
51      metadata: HashMap::new(),
52    }
53  }
54
55  /// Creates a new signal with initial metadata.
56  pub fn with_metadata(id: impl Into<String>, metadata: HashMap<String, String>) -> Self {
57    Self {
58      id: id.into(),
59      metadata,
60    }
61  }
62
63  /// Creates a signal from a typed payload implementing `SignalPayload`.
64  pub fn from_payload<P: SignalPayload>(payload: &P) -> Self {
65    Self {
66      id: payload.id().to_string(),
67      metadata: payload.to_metadata(),
68    }
69  }
70}
71
/// Conversion of a strongly typed event into a `Signal`.
///
/// Implementors supply the signal id and a flattened string metadata map;
/// `Signal::from_payload` combines the two into a `Signal`.
pub trait SignalPayload {
  /// Flattens the payload into string key/value metadata.
  fn to_metadata(&self) -> HashMap<String, String>;

  /// Canonical id for this kind of signal, e.g. "request.completed".
  fn id(&self) -> &'static str;
}
80
81/// Boxed async signal handler.
82pub type SignalHandler = Arc<dyn Fn(Signal) -> BoxFuture<'static, ()> + Send + Sync>;
83
84/// Boxed typed RPC handler used by the signal arbiter.
85pub type RpcHandler = Arc<
86  dyn Fn(Arc<dyn Any + Send + Sync>) -> BoxFuture<'static, Arc<dyn Any + Send + Sync>>
87    + Send
88    + Sync,
89>;
90
91/// Exporter callback invoked for every emitted signal.
92pub type SignalExporter = Arc<dyn Fn(&Signal) + Send + Sync>;
93
94/// Simple stream type returned by filtered subscriptions.
95pub type SignalStream = mpsc::UnboundedReceiver<Signal>;
96
97#[derive(Default)]
98struct Inner {
99  handlers: DashMap<String, Vec<SignalHandler>>,
100  topics: DashMap<String, broadcast::Sender<Signal>>,
101  rpc: DashMap<String, RpcHandler>,
102  exporters: DashMap<u64, SignalExporter>,
103}
104
105/// Shared arbiter used to register and dispatch named signals.
106#[derive(Clone, Default)]
107pub struct SignalArbiter {
108  inner: Arc<Inner>,
109}
110
111/// Global application-level signal arbiter.
112static APP_SIGNAL_ARBITER: Lazy<SignalArbiter> = Lazy::new(SignalArbiter::new);
113
114/// Returns a reference to the global application-level signal arbiter.
115pub fn app_signals() -> &'static SignalArbiter {
116  &APP_SIGNAL_ARBITER
117}
118
119/// Returns the global application-level signal arbiter.
120pub fn app_events() -> &'static SignalArbiter {
121  app_signals()
122}
123
/// Error type for typed RPC calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RpcError {
  /// No RPC handler is registered under the requested id.
  NoHandler,
  /// The value produced for the call could not be downcast to the requested
  /// response type.
  TypeMismatch,
}

impl std::fmt::Display for RpcError {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      RpcError::NoHandler => f.write_str("no RPC handler registered for this id"),
      RpcError::TypeMismatch => f.write_str("RPC type mismatch"),
    }
  }
}

impl std::error::Error for RpcError {}

/// Result type for RPC calls with explicit error reporting.
pub type RpcResult<T> = Result<T, RpcError>;

/// Error type for RPC calls with timeout support.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RpcTimeoutError {
  /// The call did not finish within the allotted duration.
  Timeout,
  /// The call finished in time but returned an `RpcError`.
  Rpc(RpcError),
}

impl std::fmt::Display for RpcTimeoutError {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      RpcTimeoutError::Timeout => f.write_str("RPC call timed out"),
      // Transparent: show the underlying RPC error as-is.
      RpcTimeoutError::Rpc(err) => std::fmt::Display::fmt(err, f),
    }
  }
}

impl std::error::Error for RpcTimeoutError {}

impl From<RpcError> for RpcTimeoutError {
  fn from(err: RpcError) -> Self {
    RpcTimeoutError::Rpc(err)
  }
}
140
141impl SignalArbiter {
142  /// Creates a new, empty signal arbiter.
143  pub fn new() -> Self {
144    Self::default()
145  }
146
147  /// Sets the global broadcast capacity used for topic channels.
148  ///
149  /// This affects all newly created topics across all arbiters.
150  pub fn set_global_broadcast_capacity(capacity: usize) {
151    let cap = capacity.max(1);
152    GLOBAL_BROADCAST_CAPACITY.store(cap, Ordering::SeqCst);
153  }
154
155  /// Returns the current global broadcast capacity.
156  pub fn global_broadcast_capacity() -> usize {
157    GLOBAL_BROADCAST_CAPACITY.load(Ordering::SeqCst)
158  }
159
160  /// Returns (and lazily initializes) the broadcast sender for a signal id.
161  pub(crate) fn topic_sender(&self, id: &str) -> broadcast::Sender<Signal> {
162    if let Some(existing) = self.inner.topics.get(id) {
163      existing.clone()
164    } else {
165      let cap = GLOBAL_BROADCAST_CAPACITY.load(Ordering::SeqCst);
166      let (tx, _rx) = broadcast::channel(cap);
167      let entry = self.inner.topics.entry(id.to_string()).or_insert(tx);
168      entry.clone()
169    }
170  }
171
172  /// Registers a handler for the given signal id.
173  ///
174  /// Handlers are invoked in registration order whenever a matching signal is emitted.
175  pub fn on<F, Fut>(&self, id: impl Into<String>, handler: F)
176  where
177    F: Fn(Signal) -> Fut + Send + Sync + 'static,
178    Fut: std::future::Future<Output = ()> + Send + 'static,
179  {
180    let id = id.into();
181    let handler: SignalHandler = Arc::new(move |signal: Signal| {
182      let fut = handler(signal);
183      Box::pin(async move { fut.await })
184    });
185
186    self
187      .inner
188      .handlers
189      .entry(id)
190      .or_insert_with(Vec::new)
191      .push(handler);
192  }
193
194  /// Subscribes to a broadcast channel for the given signal id.
195  ///
196  /// This is useful for long-lived listeners such as metrics collectors,
197  /// background workers, plugins, or middleware driven tasks.
198  pub fn subscribe(&self, id: impl AsRef<str>) -> broadcast::Receiver<Signal> {
199    let id_str = id.as_ref();
200    let sender = self.topic_sender(id_str);
201    sender.subscribe()
202  }
203
204  /// Subscribes to all signals whose id starts with the given prefix.
205  ///
206  /// For example, `subscribe_prefix("request.")` will receive
207  /// `request.started`, `request.completed`, etc.
208  pub fn subscribe_prefix(&self, prefix: impl AsRef<str>) -> broadcast::Receiver<Signal> {
209    let mut key = prefix.as_ref().to_string();
210    if !key.ends_with('*') {
211      key.push('*');
212    }
213    let sender = self.topic_sender(&key);
214    sender.subscribe()
215  }
216
217  /// Subscribes to all signals regardless of their id.
218  ///
219  /// This is a special variant that receives every emitted signal.
220  /// Internally uses a wildcard prefix matching (empty prefix = all signals).
221  pub fn subscribe_all(&self) -> broadcast::Receiver<Signal> {
222    self.subscribe_prefix("")
223  }
224
225  /// Broadcasts a signal to all subscribers without awaiting handler completion.
226  pub(crate) fn broadcast(&self, signal: Signal) {
227    // Exact id subscribers
228    if let Some(sender) = self.inner.topics.get(&signal.id) {
229      let _ = sender.send(signal.clone());
230    }
231
232    // Prefix subscribers: keys ending with '*'
233    for entry in self.inner.topics.iter() {
234      let key = entry.key();
235      if let Some(prefix) = key.strip_suffix('*') {
236        if signal.id.starts_with(prefix) {
237          let _ = entry.value().send(signal.clone());
238        }
239      }
240    }
241  }
242
243  /// Subscribes using a filter function on top of an id-based subscription.
244  ///
245  /// This spawns a background task that forwards only matching signals into
246  /// an unbounded channel, which is returned as a `SignalStream`.
247  pub fn subscribe_filtered<F>(&self, id: impl AsRef<str>, filter: F) -> SignalStream
248  where
249    F: Fn(&Signal) -> bool + Send + Sync + 'static,
250  {
251    let mut rx = self.subscribe(id);
252    let (tx, out_rx) = mpsc::unbounded_channel();
253    let filter = Arc::new(filter);
254
255    tokio::spawn(async move {
256      while let Ok(signal) = rx.recv().await {
257        if filter(&signal) {
258          if tx.send(signal).is_err() {
259            break;
260          }
261        }
262      }
263    });
264
265    out_rx
266  }
267
268  /// Waits for the next occurrence of a signal id (oneshot-style).
269  ///
270  /// This uses the broadcast channel under the hood but resolves on the
271  /// first successfully received signal.
272  pub async fn once(&self, id: impl AsRef<str>) -> Option<Signal> {
273    let mut rx = self.subscribe(id);
274    loop {
275      match rx.recv().await {
276        Ok(sig) => return Some(sig),
277        Err(broadcast::error::RecvError::Lagged(_)) => continue,
278        Err(_) => return None,
279      }
280    }
281  }
282
283  /// Registers a typed RPC handler under the given id.
284  ///
285  /// This allows request/response style interactions over the same arbiter,
286  /// using type-erased storage internally for flexibility.
287  pub fn register_rpc<Req, Res, F, Fut>(&self, id: impl Into<String>, f: F)
288  where
289    Req: Send + Sync + 'static,
290    Res: Send + Sync + 'static,
291    F: Fn(Arc<Req>) -> Fut + Send + Sync + 'static,
292    Fut: std::future::Future<Output = Res> + Send + 'static,
293  {
294    let id_str = id.into();
295    let id_for_panic = id_str.clone();
296    let func = Arc::new(f);
297
298    let handler: RpcHandler = Arc::new(move |raw: Arc<dyn Any + Send + Sync>| {
299      let func = func.clone();
300      let id_for_panic = id_for_panic.clone();
301      Box::pin(async move {
302        let req = raw
303          .downcast::<Req>()
304          .unwrap_or_else(|_| panic!("Signal RPC type mismatch for id: {}", id_for_panic));
305        let res = func(req).await;
306        Arc::new(res) as Arc<dyn Any + Send + Sync>
307      })
308    });
309
310    self.inner.rpc.insert(id_str, handler);
311  }
312
313  /// Calls a typed RPC handler and returns a shared pointer to the response.
314  pub async fn call_rpc_arc<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> Option<Arc<Res>>
315  where
316    Req: Send + Sync + 'static,
317    Res: Send + Sync + 'static,
318  {
319    let id_str = id.as_ref();
320    let entry = self.inner.rpc.get(id_str)?;
321    let handler = entry.clone();
322    drop(entry);
323
324    let raw_req: Arc<dyn Any + Send + Sync> = Arc::new(req);
325    let raw_res = handler(raw_req).await;
326
327    match raw_res.downcast::<Res>() {
328      Ok(res) => Some(res),
329      Err(_) => None,
330    }
331  }
332
333  /// Calls a typed RPC handler and returns an owned response with an error type.
334  pub async fn call_rpc_result<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> RpcResult<Res>
335  where
336    Req: Send + Sync + 'static,
337    Res: Send + Sync + Clone + 'static,
338  {
339    let id_str = id.as_ref();
340    let entry = self.inner.rpc.get(id_str);
341    let entry = match entry {
342      Some(e) => e,
343      None => return Err(RpcError::NoHandler),
344    };
345    let handler = entry.clone();
346    drop(entry);
347
348    let raw_req: Arc<dyn Any + Send + Sync> = Arc::new(req);
349    let raw_res = handler(raw_req).await;
350
351    match raw_res.downcast::<Res>() {
352      Ok(res) => Ok((*res).clone()),
353      Err(_) => Err(RpcError::TypeMismatch),
354    }
355  }
356
357  /// Calls a typed RPC handler and returns an owned response.
358  pub async fn call_rpc<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> Option<Res>
359  where
360    Req: Send + Sync + 'static,
361    Res: Send + Sync + Clone + 'static,
362  {
363    self.call_rpc_result::<Req, Res>(id, req).await.ok()
364  }
365
366  /// Calls a typed RPC handler with a timeout.
367  pub async fn call_rpc_timeout<Req, Res>(
368    &self,
369    id: impl AsRef<str>,
370    req: Req,
371    dur: Duration,
372  ) -> Result<Res, RpcTimeoutError>
373  where
374    Req: Send + Sync + 'static,
375    Res: Send + Sync + Clone + 'static,
376  {
377    match timeout(dur, self.call_rpc_result::<Req, Res>(id, req)).await {
378      Ok(Ok(res)) => Ok(res),
379      Ok(Err(e)) => Err(RpcTimeoutError::Rpc(e)),
380      Err(_) => Err(RpcTimeoutError::Timeout),
381    }
382  }
383
384  /// Emits a signal and awaits all registered handlers.
385  ///
386  /// Handlers run concurrently and this method resolves once all handlers have completed.
387  pub async fn emit(&self, signal: Signal) {
388    // First, broadcast to any subscribers.
389    self.broadcast(signal.clone());
390
391    // Call exporters (non-blocking from the perspective of handlers).
392    for entry in self.inner.exporters.iter() {
393      (entry.value())(&signal);
394    }
395
396    if let Some(entry) = self.inner.handlers.get(&signal.id) {
397      let handlers = entry.clone();
398      drop(entry);
399
400      let futures = handlers.into_iter().map(|handler| {
401        let s = signal.clone();
402        handler(s)
403      });
404
405      let _ = join_all(futures).await;
406    }
407  }
408
409  /// Emits a signal using the global application-level arbiter.
410  pub async fn emit_app(signal: Signal) {
411    app_signals().emit(signal).await;
412  }
413
414  /// Registers a global exporter that is invoked for every emitted signal.
415  ///
416  /// Exporters are merged when routers are merged, similar to handlers.
417  pub fn register_exporter<F>(&self, exporter: F)
418  where
419    F: Fn(&Signal) + Send + Sync + 'static,
420  {
421    // Use the pointer address as a simple, best-effort key.
422    let key = Arc::into_raw(Arc::new(())) as u64;
423    let exporter: SignalExporter = Arc::new(exporter);
424    self.inner.exporters.insert(key, exporter);
425  }
426
427  /// Merges all handlers from `other` into `self`.
428  ///
429  /// This is used by router merging so that signal handlers attached to
430  /// a merged router continue to be active.
431  pub(crate) fn merge_from(&self, other: &SignalArbiter) {
432    for entry in other.inner.handlers.iter() {
433      let id = entry.key().clone();
434      let handlers = entry.value().clone();
435
436      self
437        .inner
438        .handlers
439        .entry(id)
440        .or_insert_with(Vec::new)
441        .extend(handlers);
442    }
443
444    for entry in other.inner.topics.iter() {
445      let id = entry.key().clone();
446      let sender = entry.value().clone();
447      self.inner.topics.entry(id).or_insert(sender);
448    }
449
450    for entry in other.inner.rpc.iter() {
451      let id = entry.key().clone();
452      let handler = entry.value().clone();
453      self.inner.rpc.insert(id, handler);
454    }
455
456    for entry in other.inner.exporters.iter() {
457      let key = entry.key().clone();
458      let exporter = entry.value().clone();
459      self.inner.exporters.insert(key, exporter);
460    }
461  }
462
463  /// Returns a list of known signal ids (exact topics) currently registered.
464  pub fn signal_ids(&self) -> Vec<String> {
465    self
466      .inner
467      .topics
468      .iter()
469      .filter_map(|entry| {
470        let id = entry.key();
471        if id.ends_with('*') {
472          None
473        } else {
474          Some(id.clone())
475        }
476      })
477      .collect()
478  }
479
480  /// Returns a list of known signal prefixes (topics ending with '*').
481  pub fn signal_prefixes(&self) -> Vec<String> {
482    self
483      .inner
484      .topics
485      .iter()
486      .filter_map(|entry| {
487        let id = entry.key();
488        if id.ends_with('*') {
489          Some(id.clone())
490        } else {
491          None
492        }
493      })
494      .collect()
495  }
496
497  /// Returns a list of registered RPC ids.
498  pub fn rpc_ids(&self) -> Vec<String> {
499    self.inner.rpc.iter().map(|e| e.key().clone()).collect()
500  }
501}