tako/
signals.rs

1//! In-process signal arbiter and dispatch system.
2//!
3//! This module defines a small abstraction for named signals that can be emitted
4//! and handled within a Tako application. It is intended for cross-cutting
5//! concerns such as metrics, logging hooks, or custom application events.
6
7use crate::types::BuildHasher;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::{any::Any, collections::HashMap, sync::Arc};
10
11use dashmap::DashMap;
12use futures_util::future::{BoxFuture, join_all};
13use once_cell::sync::Lazy;
14use tokio::sync::{broadcast, mpsc};
15use tokio::time::{Duration, timeout};
16
/// Capacity given to topic broadcast channels unless it is overridden at runtime.
const DEFAULT_BROADCAST_CAPACITY: usize = 64;

/// Process-wide capacity read whenever a topic broadcast channel is created.
///
/// Starts at [`DEFAULT_BROADCAST_CAPACITY`]; changed through
/// `SignalArbiter::set_global_broadcast_capacity`.
static GLOBAL_BROADCAST_CAPACITY: AtomicUsize = AtomicUsize::new(DEFAULT_BROADCAST_CAPACITY);
19
/// Well-known signal identifiers for common lifecycle and request events.
///
/// These are plain string constants; any other string is also a valid signal id.
pub mod ids {
  // Server lifecycle.
  /// Id `"server.started"`.
  pub const SERVER_STARTED: &str = "server.started";
  /// Id `"server.stopped"`.
  pub const SERVER_STOPPED: &str = "server.stopped";

  // Connection lifecycle.
  /// Id `"connection.opened"`.
  pub const CONNECTION_OPENED: &str = "connection.opened";
  /// Id `"connection.closed"`.
  pub const CONNECTION_CLOSED: &str = "connection.closed";

  // Request lifecycle.
  /// Id `"request.started"`.
  pub const REQUEST_STARTED: &str = "request.started";
  /// Id `"request.completed"`.
  pub const REQUEST_COMPLETED: &str = "request.completed";

  // Router.
  /// Id `"router.hot_reload"`.
  pub const ROUTER_HOT_RELOAD: &str = "router.hot_reload";

  // RPC.
  /// Id `"rpc.error"`.
  pub const RPC_ERROR: &str = "rpc.error";

  // Per-route request lifecycle.
  /// Id `"route.request.started"`.
  pub const ROUTE_REQUEST_STARTED: &str = "route.request.started";
  /// Id `"route.request.completed"`.
  pub const ROUTE_REQUEST_COMPLETED: &str = "route.request.completed";
}
33
34/// A signal emitted through the arbiter.
35///
36/// Signals are identified by an arbitrary string and can carry a map of
37/// metadata. Callers are free to define their own conventions for ids and
38/// fields.
39#[derive(Clone, Debug, Default)]
40pub struct Signal {
41  /// Identifier of the signal, for example "request.started" or "metrics.tick".
42  pub id: String,
43  /// Optional metadata payload carried with the signal.
44  pub metadata: HashMap<String, String, BuildHasher>,
45}
46
47impl Signal {
48  /// Creates a new signal with the given id and empty metadata.
49  pub fn new(id: impl Into<String>) -> Self {
50    Self {
51      id: id.into(),
52      metadata: HashMap::with_hasher(BuildHasher::default()),
53    }
54  }
55
56  /// Creates a new signal with initial metadata.
57  pub fn with_metadata(
58    id: impl Into<String>,
59    metadata: HashMap<String, String, BuildHasher>,
60  ) -> Self {
61    Self {
62      id: id.into(),
63      metadata,
64    }
65  }
66
67  /// Creates a signal from a typed payload implementing `SignalPayload`.
68  pub fn from_payload<P: SignalPayload>(payload: &P) -> Self {
69    Self {
70      id: payload.id().to_string(),
71      metadata: payload.to_metadata(),
72    }
73  }
74}
75
76/// Trait for types that can be converted into a `Signal`.
77pub trait SignalPayload {
78  /// The canonical id for this kind of signal, e.g. "request.completed".
79  fn id(&self) -> &'static str;
80
81  /// Serializes the payload into the metadata map.
82  fn to_metadata(&self) -> HashMap<String, String, BuildHasher>;
83}
84
85/// Boxed async signal handler.
86pub type SignalHandler = Arc<dyn Fn(Signal) -> BoxFuture<'static, ()> + Send + Sync>;
87
88/// Boxed typed RPC handler used by the signal arbiter.
89pub type RpcHandler = Arc<
90  dyn Fn(Arc<dyn Any + Send + Sync>) -> BoxFuture<'static, Arc<dyn Any + Send + Sync>>
91    + Send
92    + Sync,
93>;
94
95/// Exporter callback invoked for every emitted signal.
96pub type SignalExporter = Arc<dyn Fn(&Signal) + Send + Sync>;
97
98/// Simple stream type returned by filtered subscriptions.
99pub type SignalStream = mpsc::UnboundedReceiver<Signal>;
100
101#[derive(Default)]
102struct Inner {
103  handlers: DashMap<String, Vec<SignalHandler>>,
104  topics: DashMap<String, broadcast::Sender<Signal>>,
105  rpc: DashMap<String, RpcHandler>,
106  exporters: DashMap<u64, SignalExporter>,
107}
108
109/// Shared arbiter used to register and dispatch named signals.
110#[derive(Clone, Default)]
111pub struct SignalArbiter {
112  inner: Arc<Inner>,
113}
114
115/// Global application-level signal arbiter.
116static APP_SIGNAL_ARBITER: Lazy<SignalArbiter> = Lazy::new(SignalArbiter::new);
117
118/// Returns a reference to the global application-level signal arbiter.
119pub fn app_signals() -> &'static SignalArbiter {
120  &APP_SIGNAL_ARBITER
121}
122
123/// Returns the global application-level signal arbiter.
124pub fn app_events() -> &'static SignalArbiter {
125  app_signals()
126}
127
/// Error returned by the typed RPC helpers on `SignalArbiter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RpcError {
  /// No RPC handler is registered under the requested id.
  NoHandler,
  /// The value produced for the call could not be downcast to the response type
  /// requested by the caller.
  TypeMismatch,
}

impl std::fmt::Display for RpcError {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      Self::NoHandler => f.write_str("no RPC handler registered for this id"),
      Self::TypeMismatch => f.write_str("RPC type mismatch"),
    }
  }
}

// Allows `RpcError` to flow through `Box<dyn Error>` / `?`-based error plumbing.
impl std::error::Error for RpcError {}
134
135/// Result type for RPC calls with explicit error reporting.
136pub type RpcResult<T> = Result<T, RpcError>;
137
138/// Error type for RPC calls with timeout support.
139#[derive(Debug, Clone)]
140pub enum RpcTimeoutError {
141  Timeout,
142  Rpc(RpcError),
143}
144
145impl SignalArbiter {
146  /// Creates a new, empty signal arbiter.
147  pub fn new() -> Self {
148    Self::default()
149  }
150
151  /// Sets the global broadcast capacity used for topic channels.
152  ///
153  /// This affects all newly created topics across all arbiters.
154  pub fn set_global_broadcast_capacity(capacity: usize) {
155    let cap = capacity.max(1);
156    GLOBAL_BROADCAST_CAPACITY.store(cap, Ordering::SeqCst);
157  }
158
159  /// Returns the current global broadcast capacity.
160  pub fn global_broadcast_capacity() -> usize {
161    GLOBAL_BROADCAST_CAPACITY.load(Ordering::SeqCst)
162  }
163
164  /// Returns (and lazily initializes) the broadcast sender for a signal id.
165  pub(crate) fn topic_sender(&self, id: &str) -> broadcast::Sender<Signal> {
166    if let Some(existing) = self.inner.topics.get(id) {
167      existing.clone()
168    } else {
169      let cap = GLOBAL_BROADCAST_CAPACITY.load(Ordering::SeqCst);
170      let (tx, _rx) = broadcast::channel(cap);
171      let entry = self.inner.topics.entry(id.to_string()).or_insert(tx);
172      entry.clone()
173    }
174  }
175
176  /// Registers a handler for the given signal id.
177  ///
178  /// Handlers are invoked in registration order whenever a matching signal is emitted.
179  pub fn on<F, Fut>(&self, id: impl Into<String>, handler: F)
180  where
181    F: Fn(Signal) -> Fut + Send + Sync + 'static,
182    Fut: std::future::Future<Output = ()> + Send + 'static,
183  {
184    let id = id.into();
185    let handler: SignalHandler = Arc::new(move |signal: Signal| {
186      let fut = handler(signal);
187      Box::pin(async move { fut.await })
188    });
189
190    self
191      .inner
192      .handlers
193      .entry(id)
194      .or_insert_with(Vec::new)
195      .push(handler);
196  }
197
198  /// Subscribes to a broadcast channel for the given signal id.
199  ///
200  /// This is useful for long-lived listeners such as metrics collectors,
201  /// background workers, plugins, or middleware driven tasks.
202  pub fn subscribe(&self, id: impl AsRef<str>) -> broadcast::Receiver<Signal> {
203    let id_str = id.as_ref();
204    let sender = self.topic_sender(id_str);
205    sender.subscribe()
206  }
207
208  /// Subscribes to all signals whose id starts with the given prefix.
209  ///
210  /// For example, `subscribe_prefix("request.")` will receive
211  /// `request.started`, `request.completed`, etc.
212  pub fn subscribe_prefix(&self, prefix: impl AsRef<str>) -> broadcast::Receiver<Signal> {
213    let mut key = prefix.as_ref().to_string();
214    if !key.ends_with('*') {
215      key.push('*');
216    }
217    let sender = self.topic_sender(&key);
218    sender.subscribe()
219  }
220
221  /// Subscribes to all signals regardless of their id.
222  ///
223  /// This is a special variant that receives every emitted signal.
224  /// Internally uses a wildcard prefix matching (empty prefix = all signals).
225  pub fn subscribe_all(&self) -> broadcast::Receiver<Signal> {
226    self.subscribe_prefix("")
227  }
228
229  /// Broadcasts a signal to all subscribers without awaiting handler completion.
230  pub(crate) fn broadcast(&self, signal: Signal) {
231    // Exact id subscribers
232    if let Some(sender) = self.inner.topics.get(&signal.id) {
233      let _ = sender.send(signal.clone());
234    }
235
236    // Prefix subscribers: keys ending with '*'
237    for entry in self.inner.topics.iter() {
238      let key = entry.key();
239      if let Some(prefix) = key.strip_suffix('*') {
240        if signal.id.starts_with(prefix) {
241          let _ = entry.value().send(signal.clone());
242        }
243      }
244    }
245  }
246
247  /// Subscribes using a filter function on top of an id-based subscription.
248  ///
249  /// This spawns a background task that forwards only matching signals into
250  /// an unbounded channel, which is returned as a `SignalStream`.
251  pub fn subscribe_filtered<F>(&self, id: impl AsRef<str>, filter: F) -> SignalStream
252  where
253    F: Fn(&Signal) -> bool + Send + Sync + 'static,
254  {
255    let mut rx = self.subscribe(id);
256    let (tx, out_rx) = mpsc::unbounded_channel();
257    let filter = Arc::new(filter);
258
259    tokio::spawn(async move {
260      while let Ok(signal) = rx.recv().await {
261        if filter(&signal) {
262          if tx.send(signal).is_err() {
263            break;
264          }
265        }
266      }
267    });
268
269    out_rx
270  }
271
272  /// Waits for the next occurrence of a signal id (oneshot-style).
273  ///
274  /// This uses the broadcast channel under the hood but resolves on the
275  /// first successfully received signal.
276  pub async fn once(&self, id: impl AsRef<str>) -> Option<Signal> {
277    let mut rx = self.subscribe(id);
278    loop {
279      match rx.recv().await {
280        Ok(sig) => return Some(sig),
281        Err(broadcast::error::RecvError::Lagged(_)) => continue,
282        Err(_) => return None,
283      }
284    }
285  }
286
287  /// Registers a typed RPC handler under the given id.
288  ///
289  /// This allows request/response style interactions over the same arbiter,
290  /// using type-erased storage internally for flexibility.
291  pub fn register_rpc<Req, Res, F, Fut>(&self, id: impl Into<String>, f: F)
292  where
293    Req: Send + Sync + 'static,
294    Res: Send + Sync + 'static,
295    F: Fn(Arc<Req>) -> Fut + Send + Sync + 'static,
296    Fut: std::future::Future<Output = Res> + Send + 'static,
297  {
298    let id_str = id.into();
299    let id_for_panic = id_str.clone();
300    let func = Arc::new(f);
301
302    let handler: RpcHandler = Arc::new(move |raw: Arc<dyn Any + Send + Sync>| {
303      let func = func.clone();
304      let id_for_panic = id_for_panic.clone();
305      Box::pin(async move {
306        let req = raw
307          .downcast::<Req>()
308          .unwrap_or_else(|_| panic!("Signal RPC type mismatch for id: {}", id_for_panic));
309        let res = func(req).await;
310        Arc::new(res) as Arc<dyn Any + Send + Sync>
311      })
312    });
313
314    self.inner.rpc.insert(id_str, handler);
315  }
316
317  /// Calls a typed RPC handler and returns a shared pointer to the response.
318  pub async fn call_rpc_arc<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> Option<Arc<Res>>
319  where
320    Req: Send + Sync + 'static,
321    Res: Send + Sync + 'static,
322  {
323    let id_str = id.as_ref();
324    let entry = self.inner.rpc.get(id_str)?;
325    let handler = entry.clone();
326    drop(entry);
327
328    let raw_req: Arc<dyn Any + Send + Sync> = Arc::new(req);
329    let raw_res = handler(raw_req).await;
330
331    match raw_res.downcast::<Res>() {
332      Ok(res) => Some(res),
333      Err(_) => None,
334    }
335  }
336
337  /// Calls a typed RPC handler and returns an owned response with an error type.
338  pub async fn call_rpc_result<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> RpcResult<Res>
339  where
340    Req: Send + Sync + 'static,
341    Res: Send + Sync + Clone + 'static,
342  {
343    let id_str = id.as_ref();
344    let entry = self.inner.rpc.get(id_str);
345    let entry = match entry {
346      Some(e) => e,
347      None => return Err(RpcError::NoHandler),
348    };
349    let handler = entry.clone();
350    drop(entry);
351
352    let raw_req: Arc<dyn Any + Send + Sync> = Arc::new(req);
353    let raw_res = handler(raw_req).await;
354
355    match raw_res.downcast::<Res>() {
356      Ok(res) => Ok((*res).clone()),
357      Err(_) => Err(RpcError::TypeMismatch),
358    }
359  }
360
361  /// Calls a typed RPC handler and returns an owned response.
362  pub async fn call_rpc<Req, Res>(&self, id: impl AsRef<str>, req: Req) -> Option<Res>
363  where
364    Req: Send + Sync + 'static,
365    Res: Send + Sync + Clone + 'static,
366  {
367    self.call_rpc_result::<Req, Res>(id, req).await.ok()
368  }
369
370  /// Calls a typed RPC handler with a timeout.
371  pub async fn call_rpc_timeout<Req, Res>(
372    &self,
373    id: impl AsRef<str>,
374    req: Req,
375    dur: Duration,
376  ) -> Result<Res, RpcTimeoutError>
377  where
378    Req: Send + Sync + 'static,
379    Res: Send + Sync + Clone + 'static,
380  {
381    match timeout(dur, self.call_rpc_result::<Req, Res>(id, req)).await {
382      Ok(Ok(res)) => Ok(res),
383      Ok(Err(e)) => Err(RpcTimeoutError::Rpc(e)),
384      Err(_) => Err(RpcTimeoutError::Timeout),
385    }
386  }
387
388  /// Emits a signal and awaits all registered handlers.
389  ///
390  /// Handlers run concurrently and this method resolves once all handlers have completed.
391  pub async fn emit(&self, signal: Signal) {
392    // First, broadcast to any subscribers.
393    self.broadcast(signal.clone());
394
395    // Call exporters (non-blocking from the perspective of handlers).
396    for entry in self.inner.exporters.iter() {
397      (entry.value())(&signal);
398    }
399
400    if let Some(entry) = self.inner.handlers.get(&signal.id) {
401      let handlers = entry.clone();
402      drop(entry);
403
404      let futures = handlers.into_iter().map(|handler| {
405        let s = signal.clone();
406        handler(s)
407      });
408
409      let _ = join_all(futures).await;
410    }
411  }
412
413  /// Emits a signal using the global application-level arbiter.
414  pub async fn emit_app(signal: Signal) {
415    app_signals().emit(signal).await;
416  }
417
418  /// Registers a global exporter that is invoked for every emitted signal.
419  ///
420  /// Exporters are merged when routers are merged, similar to handlers.
421  pub fn register_exporter<F>(&self, exporter: F)
422  where
423    F: Fn(&Signal) + Send + Sync + 'static,
424  {
425    // Use the pointer address as a simple, best-effort key.
426    let key = Arc::into_raw(Arc::new(())) as u64;
427    let exporter: SignalExporter = Arc::new(exporter);
428    self.inner.exporters.insert(key, exporter);
429  }
430
431  /// Merges all handlers from `other` into `self`.
432  ///
433  /// This is used by router merging so that signal handlers attached to
434  /// a merged router continue to be active.
435  pub(crate) fn merge_from(&self, other: &SignalArbiter) {
436    for entry in other.inner.handlers.iter() {
437      let id = entry.key().clone();
438      let handlers = entry.value().clone();
439
440      self
441        .inner
442        .handlers
443        .entry(id)
444        .or_insert_with(Vec::new)
445        .extend(handlers);
446    }
447
448    for entry in other.inner.topics.iter() {
449      let id = entry.key().clone();
450      let sender = entry.value().clone();
451      self.inner.topics.entry(id).or_insert(sender);
452    }
453
454    for entry in other.inner.rpc.iter() {
455      let id = entry.key().clone();
456      let handler = entry.value().clone();
457      self.inner.rpc.insert(id, handler);
458    }
459
460    for entry in other.inner.exporters.iter() {
461      let key = entry.key().clone();
462      let exporter = entry.value().clone();
463      self.inner.exporters.insert(key, exporter);
464    }
465  }
466
467  /// Returns a list of known signal ids (exact topics) currently registered.
468  pub fn signal_ids(&self) -> Vec<String> {
469    self
470      .inner
471      .topics
472      .iter()
473      .filter_map(|entry| {
474        let id = entry.key();
475        if id.ends_with('*') {
476          None
477        } else {
478          Some(id.clone())
479        }
480      })
481      .collect()
482  }
483
484  /// Returns a list of known signal prefixes (topics ending with '*').
485  pub fn signal_prefixes(&self) -> Vec<String> {
486    self
487      .inner
488      .topics
489      .iter()
490      .filter_map(|entry| {
491        let id = entry.key();
492        if id.ends_with('*') {
493          Some(id.clone())
494        } else {
495          None
496        }
497      })
498      .collect()
499  }
500
501  /// Returns a list of registered RPC ids.
502  pub fn rpc_ids(&self) -> Vec<String> {
503    self.inner.rpc.iter().map(|e| e.key().clone()).collect()
504  }
505}