Skip to main content

runtime_rs/
registry.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::{Arc, RwLock};
6
7use async_trait::async_trait;
8use tracing::{Instrument, error, info, warn};
9
10// =====================================================================
11// Error model
12// =====================================================================
13//
14// The error type is deliberately scoped to the `registry` module because
15// it is a *registry* concern (lifecycle stages: boot / validate / reload
16// / run). Other helper primitives in this crate (`gate`, `guard`) have their
17// own semantics and intentionally do not share this type.
18//
19// Consumers return any `std::error::Error + Send + Sync + 'static`
20// from their `Provider` / `Reloadable` / `Runnable` methods — the blanket
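// `From<E>` impl wraps it into `Error::Other`. The registry then re-wraps
// `Error::Other` into the appropriate lifecycle variant at the call site —
// `Boot` in `boot_all`, `Validate` in `validate_all`, `Run` in `run_all`, and
// `Reload` in the targeted `reload_one` (the broadcast `reload_all` only logs
// per-provider failures) — so downstream logs and matches see where the
// failure happened. Providers that want to emit a typed variant themselves
// can construct it directly and the registry will leave it untouched; the one
// exception is a `Recoverable` built by `Error::recoverable`, whose empty name
// `run_all` fills in with the provider's name.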
27
28pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
29
30pub type Result<T, E = Error> = std::result::Result<T, E>;
31
32#[derive(Debug)]
33pub enum Error {
34    Boot {
35        name: &'static str,
36        source: BoxError
37    },
38    Validate {
39        name: &'static str,
40        source: BoxError
41    },
42    Reload {
43        name: &'static str,
44        source: BoxError
45    },
46    /// Fatal runnable failure (default). The runtime tears the worker
47    /// down so the supervisor can respawn cleanly.
48    Run {
49        name: &'static str,
50        source: BoxError
51    },
52    /// Recoverable runnable failure. The runtime logs and keeps the
53    /// worker serving — used for best-effort tasks (e.g. notify
54    /// listeners, optional integrations) where a transient or
55    /// configuration-driven failure shouldn't kill traffic.
56    Recoverable {
57        name: &'static str,
58        source: BoxError
59    },
60    Other(BoxError)
61}
62
63impl std::fmt::Display for Error {
64    fn fmt(
65        &self,
66        f: &mut std::fmt::Formatter<'_>
67    ) -> std::fmt::Result {
68        match self {
69            Error::Boot { name, source } => {
70                write!(f, "provider '{name}' failed during boot: {source}")
71            }
72            Error::Validate { name, source } => {
73                write!(f, "provider '{name}' failed during validate: {source}")
74            }
75            Error::Reload { name, source } => {
76                write!(f, "reload of '{name}' failed: {source}")
77            }
78            Error::Run { name, source } => {
79                write!(f, "runnable '{name}' failed: {source}")
80            }
81            Error::Recoverable { name, source } => {
82                write!(f, "runnable '{name}' failed (recoverable): {source}")
83            }
84            Error::Other(e) => std::fmt::Display::fmt(e, f)
85        }
86    }
87}
88
89// NOTE: `Error` intentionally does NOT implement `std::error::Error`.
90// The blanket `From<E: Error>` below requires that `Error` itself not
91// satisfy that bound (otherwise it would conflict with the core
92// `From<T> for T` blanket). Consumers that need to chain `source()` can
93// match on the variant and walk `BoxError` directly.
94
95impl<E> From<E> for Error
96where
97    E: std::error::Error + Send + Sync + 'static
98{
99    fn from(e: E) -> Self {
100        Error::Other(Box::new(e))
101    }
102}
103
104/// Construct `Error::Other` from an arbitrary message string.
105impl Error {
106    pub fn msg(s: impl Into<String>) -> Self {
107        #[derive(Debug)]
108        struct MsgErr(String);
109        impl std::fmt::Display for MsgErr {
110            fn fmt(
111                &self,
112                f: &mut std::fmt::Formatter<'_>
113            ) -> std::fmt::Result {
114                std::fmt::Display::fmt(&self.0, f)
115            }
116        }
117        impl std::error::Error for MsgErr {}
118        Error::Other(Box::new(MsgErr(s.into())))
119    }
120
121    /// If the error is `Other`, re-wrap it as `Boot { name, source }`;
122    /// otherwise leave it untouched. Used by `Registry::boot_all` to
123    /// attach lifecycle context to anonymous user errors.
124    fn into_boot(
125        self,
126        name: &'static str
127    ) -> Self {
128        match self {
129            Error::Other(source) => Error::Boot { name, source },
130            other => other
131        }
132    }
133    fn into_validate(
134        self,
135        name: &'static str
136    ) -> Self {
137        match self {
138            Error::Other(source) => Error::Validate { name, source },
139            other => other
140        }
141    }
142    /// Used by `reload_one` (targeted, fail-fast). `reload_all` is broadcast
143    /// and intentionally fail-soft — a single provider's failure should not
144    /// cancel the rest, so that path just logs a warning.
145    fn into_reload(
146        self,
147        name: &'static str
148    ) -> Self {
149        match self {
150            Error::Other(source) => Error::Reload { name, source },
151            other => other
152        }
153    }
154    fn into_run(
155        self,
156        name: &'static str
157    ) -> Self {
158        match self {
159            Error::Other(source) => Error::Run { name, source },
160            // Runnables that opt into recoverable failure construct
161            // `Recoverable` with an empty `name`; `run_all` fills in the
162            // provider name here so log lines stay attributed.
163            Error::Recoverable { name: "", source } => Error::Recoverable { name, source },
164            other => other
165        }
166    }
167
168    /// Build a recoverable runnable error from an arbitrary message.
169    /// The runtime logs this and lets the worker keep serving instead of
170    /// tearing it down. The provider `name` is filled in by `run_all`'s
171    /// wrapper, so callers only supply the message.
172    pub fn recoverable(s: impl Into<String>) -> Self {
173        #[derive(Debug)]
174        struct MsgErr(String);
175        impl std::fmt::Display for MsgErr {
176            fn fmt(
177                &self,
178                f: &mut std::fmt::Formatter<'_>
179            ) -> std::fmt::Result {
180                std::fmt::Display::fmt(&self.0, f)
181            }
182        }
183        impl std::error::Error for MsgErr {}
184        Error::Recoverable { name: "", source: Box::new(MsgErr(s.into())) }
185    }
186}
187
188// =====================================================================
189// Priority helpers
190// =====================================================================
191
/// Lifecycle priority constants shared by providers and reloadables.
///
/// A smaller value is scheduled sooner; `NORMAL` is what a `None`
/// priority falls back to.
pub mod priority {
    /// Earliest possible slot.
    pub const FIRST: u8 = u8::MIN;
    /// Ahead of ordinary providers.
    pub const EARLY: u8 = 50;
    /// Default slot used when no explicit priority is given.
    pub const NORMAL: u8 = 100;
    /// After ordinary providers.
    pub const LATE: u8 = 150;
    /// Latest possible slot.
    pub const LAST: u8 = u8::MAX;
}
202
203#[async_trait]
204pub trait ReloadState: Send + Sync + Sized + 'static {
205    async fn reload(&self) -> Result<()>;
206}
207
208/// Anything that can hot-reload itself when config changes.
209///
210/// `reload()` is the same shape as `Provider::boot()` — re-read the
211/// on-disk config (use `tokio::fs`, never `std::fs` in this async path)
212/// and rebuild the runtime snapshot, publishing it through an
213/// `ArcSwap` so in-flight requests/connections see the swap atomically.
214/// Reload must NOT change which providers are registered; it only
215/// refreshes state of an already-registered provider.
216#[async_trait]
217pub trait Reloadable<S>: Send + Sync + 'static {
218    /// Optional reload priority.
219    ///
220    /// Lower values run earlier. `None` means `priority::NORMAL`.
221    fn priority(&self) -> Option<u8> {
222        None
223    }
224
225    /// Perform a synchronous reload using the current shared state.
226    /// Implementations may spawn async work internally if needed.
227    async fn reload(
228        &self,
229        state: &S
230    ) -> Result<()>;
231}
232
233/// Boxed future type returned by long-running providers (server loops/background workers).
234pub type TaskFuture = Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>;
235
236/// Capability trait for providers that produce a long-running runtime task.
237///
238/// `run()` is the ONLY place in the lifecycle for long-running work
239/// (accept loops, listeners, periodic tickers). It must NOT appear in
240/// `register()` or `Provider::boot()`.
241///
242/// Config-driven gating: if the provider is disabled at runtime (e.g.
243/// an `enabled: false` config flag, or a single-instance service whose
244/// pinned `worker_id` doesn't match this worker), the future returned
245/// here MUST short-circuit and return `Ok(())` immediately instead of
246/// starting the long task. The provider stays registered for downstream
247/// capability lookups; it just doesn't run on this process.
248pub trait Runnable<S>: Send + Sync + 'static {
249    /// Build the task future to be spawned by the bootstrap/supervisor layer.
250    ///
251    /// NOTICE (convention):
252    /// If this future returns `Err`, implementation should log contextual
253    /// failure details itself (provider/task specific metadata).
254    ///
255    /// Reason:
256    /// - Runtime layer handles lifecycle/control-flow only.
257    /// - Runtime cannot reliably attach provider-specific business context.
258    /// - Non-critical runnable errors are not centrally logged to avoid
259    ///   duplicate/no-context error lines.
260    fn run(
261        &self,
262        state: S
263    ) -> TaskFuture;
264}
265
266/// Any service that can be registered in the DI registry.
267///
268/// # Lifecycle convention
269///
270/// Each provider lives in four explicit phases. Mixing work across
271/// phase boundaries is the most common bug source — keep them strict.
272///
273/// 1. **`register()` (free fn, outside the trait)** — synchronous, no
274///    async, called once during bootstrap. Constructs the provider in
275///    a placeholder/empty state and inserts it into the registry.
276///
277///    Allowed:
278///    * Read state-level inputs (`state.run_mode()`, `state.config_dir()`)
279///      to choose what to register.
280///    * Read on-disk config synchronously *only* if the answer decides
281///      whether to register the provider at all (e.g. feature toggles,
282///      worker pinning). Use `std::fs` here — register is sync.
283///
284///    Forbidden:
285///    * Resolving other providers from the registry (they may not exist
286///      yet; ordering is settled by `boot_priority`, not by register order).
287///    * Async I/O.
288///    * Spawning tasks.
289///    * Building the operational snapshot (that's `boot()`).
290///
291/// 2. **`boot()`** — async, called after every `register()` ran, in
292///    `boot_priority` order. This is where the provider becomes usable.
293///
294///    Allowed / expected:
295///    * Resolve dependencies from the registry — by now every other
296///      `register()` has run.
297///    * Async I/O — `tokio::fs` for config, network calls, etc. Never
298///      `std::fs` (it blocks the runtime).
299///    * Build the runtime snapshot and publish it via `ArcSwap` /
300///      `ArcSwapOption` so concurrent readers see atomic swaps.
301///    * Honor disabled-state from config: leave the snapshot empty and
302///      return `Ok(())` rather than failing.
303///
304///    Forbidden:
305///    * Spawning long-running tasks. Boot must return when state is
306///      ready; the long task lives in `Runnable::run()`.
307///
308/// 3. **`Runnable::run()`** — see that trait. The only place for
309///    long-lived loops; honors disabled-state by returning `Ok(())`
310///    immediately.
311///
312/// 4. **`shutdown()`** — async best-effort cleanup after the shutdown
313///    signal has fired and before the runtime aborts any remaining
314///    runnable tasks. Use this for process-owned resources that must
315///    not leak into the next graceful boot. Default no-op.
316///
317/// Reload (`Reloadable::reload()`) follows the same shape as `boot()`.
318#[async_trait]
319pub trait Provider<S>: Any + Send + Sync + 'static {
320    /// Human-readable label for logs/diagnostics.
321    fn name(&self) -> &'static str {
322        "provider"
323    }
324
325    /// Optional boot priority. Lower values run earlier. `None` means
326    /// `priority::NORMAL`. Use `priority::FIRST` for providers others
327    /// depend on (e.g. `HttpService` publishing the parsed `http.yaml`),
328    /// `priority::AFTER` / `priority::LATE` for consumers.
329    fn boot_priority(&self) -> Option<u8> {
330        None
331    }
332
333    /// Optional runtime task start priority. Lower values run earlier.
334    /// `None` means `priority::NORMAL`.
335    fn run_priority(&self) -> Option<u8> {
336        None
337    }
338
339    /// Bootstrap-time async initialization. See the trait-level lifecycle
340    /// convention for what belongs here vs in `register()` / `run()`.
341    /// Default no-op so providers that only need `register()` insertion
342    /// don't have to implement this.
343    async fn boot(
344        &self,
345        _state: &S
346    ) -> Result<()> {
347        Ok(())
348    }
349
350    /// Graceful-shutdown cleanup hook.
351    ///
352    /// This is not a replacement for `Drop`: it is the lifecycle point
353    /// for externally named resources whose stale presence can break the
354    /// next boot, such as shm segments or lock files. Implementations
355    /// should be idempotent because shutdown paths may be re-entered.
356    async fn shutdown(
357        &self,
358        _state: &S
359    ) -> Result<()> {
360        Ok(())
361    }
362
363    /// Synchronous preflight validation.
364    ///
365    /// Call this before spawning runnable providers so bad config, missing
366    /// files, or conflicting settings fail fast before long-running work
367    /// starts. Applications may choose whether validation happens before or
368    /// after `boot_all`, depending on whether a provider needs boot-time state
369    /// to validate itself.
370    fn validate(
371        &self,
372        _state: &S
373    ) -> Result<()> {
374        Ok(())
375    }
376
377    /// Downcast hook for typed resolve APIs.
378    fn as_any(&self) -> &dyn Any
379    where
380        Self: Sized
381    {
382        self
383    }
384
385    /// Optional capability hook.
386    fn as_reloadable(&self) -> Option<&dyn Reloadable<S>> {
387        None
388    }
389
390    /// Optional capability hook.
391    fn as_runnable(&self) -> Option<&dyn Runnable<S>> {
392        None
393    }
394}
395
396/// Type-erased provider registry used for service discovery and DI-style lookup.
397/// Registration happens during bootstrap and runtime access is read-only via typed resolves.
398/// We keep the underlying maps behind `RwLock<HashMap<..>>` so registration stays simple while
399/// lookup only holds a short-lived read lock long enough to clone the stored `Arc`.
400pub struct Registry<S> {
401    providers: RwLock<HashMap<TypeId, Arc<dyn Provider<S>>>>,
402    by_type: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>
403}
404
405impl<S: 'static> Registry<S> {
406    /// Create the service with an empty registry. You can register later.
407    pub fn new() -> Self {
408        Self { providers: RwLock::new(HashMap::new()), by_type: RwLock::new(HashMap::new()) }
409    }
410
411    /// Register a provider into the registry.
412    ///
413    /// This accepts `Arc<T>` where `T: Provider`. The service is stored as a
414    /// type-erased `Arc<dyn Provider>` but continues to point to the same underlying
415    /// allocation (no new allocation is created).
416    ///
417    /// If another service with the same concrete type is already registered,
418    /// the new registration is skipped and a warning is logged.
419    ///
420    /// Returns `&Self` to allow fluent chaining:
421    ///
422    /// ```ignore
423    /// registry
424    ///     .insert(dns.clone())
425    ///     .insert(ipc.clone());
426    /// ```
427    pub fn insert<C>(
428        &self,
429        item: Arc<C>
430    ) -> &Self
431    where
432        C: Provider<S> + 'static
433    {
434        let type_id = TypeId::of::<C>();
435        let any: Arc<dyn Any + Send + Sync> = item.clone();
436        let mut by_type = self.by_type.write().expect("registry by_type lock poisoned");
437        if by_type.contains_key(&type_id) {
438            warn!(
439                "⚠️ duplicate provider type '{}' — skipping registration",
440                std::any::type_name::<C>()
441            );
442            return self;
443        }
444        by_type.insert(type_id, any);
445        drop(by_type);
446
447        let it: Arc<dyn Provider<S>> = item;
448        self.providers.write().expect("registry providers lock poisoned").insert(type_id, it);
449        self
450    }
451
452    /// Execute a closure with a concrete typed reference `&T` if the service is registered.
453    pub fn with_typed<T, R>(
454        &self,
455        f: impl FnOnce(&T) -> R
456    ) -> Option<R>
457    where
458        T: Provider<S> + 'static
459    {
460        let typed = self.resolve::<T>()?;
461        Some(f(typed.as_ref()))
462    }
463
464    /// Resolve a concrete service as an owned `Arc<T>` handle.
465    ///
466    /// This is the DI-style, high-level API: it returns a typed `Arc<T>` that
467    /// points to the same underlying allocation as the internally registered
468    /// provider (no new `Arc` allocation). The returned `Arc` is obtained by
469    /// downcasting from a type-indexed map (`TypeId`).
470    ///
471    /// Returns `None` if the type is not registered.
472    pub fn resolve<T>(&self) -> Option<Arc<T>>
473    where
474        T: Provider<S> + 'static
475    {
476        let any = self
477            .by_type
478            .read()
479            .expect("registry by_type lock poisoned")
480            .get(&TypeId::of::<T>())?
481            .clone();
482        Arc::downcast::<T>(any).ok()
483    }
484
485    /// Return a snapshot of registered providers.
486    #[allow(unused)]
487    pub fn providers(&self) -> Vec<Arc<dyn Provider<S>>> {
488        self.providers.read().expect("registry providers lock poisoned").values().cloned().collect()
489    }
490
491    /// Return the list of provider display names (for diagnostics only).
492    #[allow(unused)]
493    pub fn list_names(&self) -> Vec<&'static str> {
494        self.providers().iter().map(|c| c.name()).collect()
495    }
496
497    /// Spawn all runnable providers into the given JoinSet.
498    ///
499    /// Returns the number of tasks spawned.
500    pub fn run_all(
501        &self,
502        state: S,
503        join_set: &mut tokio::task::JoinSet<Result<()>>
504    ) -> usize
505    where
506        S: Clone + Send + 'static
507    {
508        let mut spawned = 0usize;
509        let mut providers = self.providers();
510        providers.sort_by_key(|provider| {
511            (provider.run_priority().unwrap_or(priority::NORMAL), provider.name())
512        });
513
514        for provider in providers {
515            let Some(runnable) = provider.as_runnable() else {
516                continue;
517            };
518
519            let name = provider.name();
520            let fut = runnable
521                .run(state.clone())
522                .instrument(tracing::debug_span!("provider", provider = %name));
523            join_set.spawn(async move { fut.await.map_err(|e| e.into_run(name)) });
524            spawned += 1;
525        }
526
527        spawned
528    }
529
530    /// Run `validate` hook for all registered providers.
531    pub fn validate_all(
532        &self,
533        state: &S
534    ) -> Result<()> {
535        for provider in self.providers() {
536            let name = provider.name();
537            provider.validate(state).map_err(|e| e.into_validate(name))?;
538        }
539        Ok(())
540    }
541
542    pub async fn boot_all(
543        &self,
544        state: &S
545    ) -> Result<()> {
546        let mut providers = self.providers();
547        providers.sort_by_key(|provider| {
548            (provider.boot_priority().unwrap_or(priority::NORMAL), provider.name())
549        });
550
551        for provider in providers {
552            let name = provider.name();
553            // debug!("🚀 booting provider '{}'", name);
554            if let Err(e) = provider.boot(state).await {
555                error!("❌ boot provider '{}' failed: {}", name, e);
556                return Err(e.into_boot(name));
557            }
558            // debug!("✅ provider '{}' booted", name);
559        }
560        Ok(())
561    }
562
563    pub async fn shutdown_all(
564        &self,
565        state: &S
566    ) -> Result<()> {
567        let mut providers = self.providers();
568        providers.sort_by_key(|provider| {
569            (provider.boot_priority().unwrap_or(priority::NORMAL), provider.name())
570        });
571        providers.reverse();
572
573        for provider in providers {
574            let name = provider.name();
575            if let Err(e) = provider.shutdown(state).await {
576                warn!("shutdown of provider '{}' failed: {}", name, e);
577            }
578        }
579        Ok(())
580    }
581
582    pub async fn reload_one(
583        &self,
584        name: &str,
585        state: &S
586    ) -> Result<()> {
587        let Some(provider) = self.providers().into_iter().find(|provider| provider.name() == name)
588        else {
589            return Err(Error::msg(format!(
590                "reload_by_name: no provider registered with name '{}'",
591                name
592            )));
593        };
594
595        let Some(reloadable) = provider.as_reloadable() else {
596            return Err(Error::msg(format!(
597                "reload_by_name: provider '{}' is not reloadable",
598                name
599            )));
600        };
601
602        info!("♻️  reloading service '{}'", name);
603
604        match reloadable.reload(state).await {
605            Ok(()) => {
606                info!("♻️  {} reloaded", name);
607                Ok(())
608            }
609            Err(e) => {
610                warn!("❌ reload of {} failed: {e}", name);
611                // Resolve the static name from the provider before consuming it.
612                let static_name = provider.name();
613                Err(e.into_reload(static_name))
614            }
615        }
616    }
617}
618
619impl<S> Registry<S>
620where
621    S: ReloadState + 'static
622{
623    pub async fn reload_all(
624        &self,
625        state: &S
626    ) -> Result<()> {
627        state.reload().await?;
628
629        info!("✅ state reloaded");
630
631        let mut list: Vec<(u8, &'static str, Arc<dyn Provider<S>>)> = self
632            .providers()
633            .into_iter()
634            .filter_map(|provider| {
635                let reloadable = provider.as_reloadable()?;
636                Some((reloadable.priority().unwrap_or(priority::NORMAL), provider.name(), provider))
637            })
638            .collect();
639
640        // deterministic order: priority first, name second.
641        list.sort_by_key(|(priority, name, _)| (*priority, *name));
642
643        for (_, name, provider) in list {
644            if let Some(reloadable) = provider.as_reloadable() {
645                if let Err(e) = reloadable.reload(state).await {
646                    warn!("❌ reload of {} failed: {e}", name);
647                } else {
648                    info!("♻️  {} reloaded", name);
649                }
650            }
651        }
652
653        Ok(())
654    }
655}
656
657impl<S: 'static> Default for Registry<S> {
658    fn default() -> Self {
659        Self::new()
660    }
661}