Skip to main content

runtime_rs/
registry.rs

1use std::any::{Any, TypeId};
2use std::collections::{HashMap, HashSet};
3use std::sync::{Arc, RwLock};
4
5use async_trait::async_trait;
6use tracing::{Instrument, error, info, warn};
7
8// =====================================================================
9// Error model
10// =====================================================================
11//
12// The error type is deliberately scoped to the `registry` module because
13// it is a *registry* concern (lifecycle stages: boot / validate / reload
14// / run). Other helper primitives in this crate (`gate`, `guard`) have their
15// own semantics and intentionally do not share this type.
16//
17// Consumers return any `std::error::Error + Send + Sync + 'static`
18// from their `Provider` / `Reloadable` / `Runnable` methods — the blanket
19// `From<E>` impl wraps it into `Error::Other`. The registry then re-wraps
20// `Error::Other` into the appropriate lifecycle variant (`Boot`, `Reload`,
21// `Run`, `Validate`) at the call site, so downstream logs and matches see
22// where the failure happened. Providers that want to emit a typed variant
23// themselves can construct it directly — the registry will leave it
24// untouched.
25
/// Boxed, type-erased error that is `Send + Sync + 'static`.
///
/// This is the payload type stored in every [`Error`] variant.
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// `Result` alias for this module; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
29
/// Registry error tagged with the lifecycle stage in which it occurred.
///
/// Every named variant carries the failing provider's `name` and the
/// underlying `source`; `Other` holds a cause that has not been attributed
/// to a stage.
#[derive(Debug)]
pub enum Error {
    /// A provider failed during boot.
    Boot {
        name: &'static str,
        source: BoxError,
    },
    /// A provider failed during validate.
    Validate {
        name: &'static str,
        source: BoxError,
    },
    /// Reloading a provider failed.
    Reload {
        name: &'static str,
        source: BoxError,
    },
    /// Fatal runnable failure (default). The runtime tears the worker
    /// down so the supervisor can respawn cleanly.
    Run {
        name: &'static str,
        source: BoxError,
    },
    /// Recoverable runnable failure. The runtime logs and keeps the
    /// worker serving — used for best-effort tasks (e.g. notify
    /// listeners, optional integrations) where a transient or
    /// configuration-driven failure shouldn't kill traffic.
    ///
    /// An empty `name` is a placeholder: `Error::recoverable` creates it
    /// and `into_run` replaces it with the provider name.
    Recoverable {
        name: &'static str,
        source: BoxError,
    },
    /// Unattributed error, as produced by the blanket `From<E>` impl and
    /// by `Error::msg`.
    Other(BoxError),
}
60
61impl std::fmt::Display for Error {
62    fn fmt(
63        &self,
64        f: &mut std::fmt::Formatter<'_>,
65    ) -> std::fmt::Result {
66        match self {
67            Error::Boot { name, source } => {
68                write!(f, "provider '{name}' failed during boot: {source}")
69            }
70            Error::Validate { name, source } => {
71                write!(f, "provider '{name}' failed during validate: {source}")
72            }
73            Error::Reload { name, source } => {
74                write!(f, "reload of '{name}' failed: {source}")
75            }
76            Error::Run { name, source } => {
77                write!(f, "runnable '{name}' failed: {source}")
78            }
79            Error::Recoverable { name, source } => {
80                write!(f, "runnable '{name}' failed (recoverable): {source}")
81            }
82            Error::Other(e) => std::fmt::Display::fmt(e, f),
83        }
84    }
85}
86
87// NOTE: `Error` intentionally does NOT implement `std::error::Error`.
88// The blanket `From<E: Error>` below requires that `Error` itself not
89// satisfy that bound (otherwise it would conflict with the core
90// `From<T> for T` blanket). Consumers that need to chain `source()` can
91// match on the variant and walk `BoxError` directly.
92
93impl<E> From<E> for Error
94where
95    E: std::error::Error + Send + Sync + 'static,
96{
97    fn from(e: E) -> Self {
98        Error::Other(Box::new(e))
99    }
100}
101
102/// Construct `Error::Other` from an arbitrary message string.
103impl Error {
104    pub fn msg(s: impl Into<String>) -> Self {
105        #[derive(Debug)]
106        struct MsgErr(String);
107        impl std::fmt::Display for MsgErr {
108            fn fmt(
109                &self,
110                f: &mut std::fmt::Formatter<'_>,
111            ) -> std::fmt::Result {
112                std::fmt::Display::fmt(&self.0, f)
113            }
114        }
115        impl std::error::Error for MsgErr {}
116        Error::Other(Box::new(MsgErr(s.into())))
117    }
118
119    /// If the error is `Other`, re-wrap it as `Boot { name, source }`;
120    /// otherwise leave it untouched. Used by `Registry::boot_all` to
121    /// attach lifecycle context to anonymous user errors.
122    fn into_boot(
123        self,
124        name: &'static str,
125    ) -> Self {
126        match self {
127            Error::Other(source) => Error::Boot { name, source },
128            other => other,
129        }
130    }
131    fn into_validate(
132        self,
133        name: &'static str,
134    ) -> Self {
135        match self {
136            Error::Other(source) => Error::Validate { name, source },
137            other => other,
138        }
139    }
140    /// Used by `reload_one` (targeted, fail-fast). `reload_all` is broadcast
141    /// and intentionally fail-soft — a single provider's failure should not
142    /// cancel the rest, so that path just logs a warning.
143    fn into_reload(
144        self,
145        name: &'static str,
146    ) -> Self {
147        match self {
148            Error::Other(source) => Error::Reload { name, source },
149            other => other,
150        }
151    }
152    fn into_run(
153        self,
154        name: &'static str,
155    ) -> Self {
156        match self {
157            Error::Other(source) => Error::Run { name, source },
158            // Runnables that opt into recoverable failure construct
159            // `Recoverable` with an empty `name`; `run_all` fills in the
160            // provider name here so log lines stay attributed.
161            Error::Recoverable { name: "", source } => Error::Recoverable { name, source },
162            other => other,
163        }
164    }
165
166    /// Build a recoverable runnable error from an arbitrary message.
167    /// The runtime logs this and lets the worker keep serving instead of
168    /// tearing it down. The provider `name` is filled in by `run_all`'s
169    /// wrapper, so callers only supply the message.
170    pub fn recoverable(s: impl Into<String>) -> Self {
171        #[derive(Debug)]
172        struct MsgErr(String);
173        impl std::fmt::Display for MsgErr {
174            fn fmt(
175                &self,
176                f: &mut std::fmt::Formatter<'_>,
177            ) -> std::fmt::Result {
178                std::fmt::Display::fmt(&self.0, f)
179            }
180        }
181        impl std::error::Error for MsgErr {}
182        Error::Recoverable { name: "", source: Box::new(MsgErr(s.into())) }
183    }
184}
185
186// =====================================================================
187// Priority helpers
188// =====================================================================
189
/// Shared lifecycle priority definitions for providers/reloadables.
///
/// Lower values run earlier among providers that are otherwise ready.
/// Prefer `ProviderOrder` for real dependencies; priorities are only
/// coarse tie-breakers for legacy/simple cases.
pub mod priority {
    /// Reserved floor for workspace-internal root providers.
    ///
    /// Ordinary providers should use `EARLY`, `NORMAL`, `LATE`, or explicit
    /// `ProviderOrder` edges instead of depending on this extreme value.
    #[doc(hidden)]
    pub const FIRST: u8 = 0;
    /// Runs ahead of `NORMAL` providers.
    pub const EARLY: u8 = 50;
    /// Default used by the registry when a priority hook returns `None`.
    pub const NORMAL: u8 = 100;
    /// Runs after `NORMAL` providers.
    pub const LATE: u8 = 150;
    /// Reserved ceiling for final workspace-internal lifecycle providers.
    ///
    /// Other providers should use `LATE` plus explicit `ProviderOrder` edges
    /// when they need to be late.
    #[doc(hidden)]
    pub const LAST: u8 = u8::MAX;
}
212
/// Explicit ordering edges between provider concrete types.
///
/// Numeric priorities remain a coarse tie-breaker; a `ProviderOrder` lets a
/// provider state "run me before `T`" / "run me after `T`" by type, without
/// magic numbers or provider names. Edges are kept in the order they were
/// declared.
#[derive(Clone, Debug, Default)]
pub struct ProviderOrder {
    /// Types that the declaring provider must precede.
    before: Vec<TypeId>,
    /// Types that the declaring provider must follow.
    after: Vec<TypeId>,
}

impl ProviderOrder {
    /// Start with no ordering edges.
    pub fn new() -> Self {
        Self { before: Vec::new(), after: Vec::new() }
    }

    /// Builder step: the declaring provider runs before `T`.
    pub fn before<T: 'static>(self) -> Self {
        let Self { mut before, after } = self;
        before.push(TypeId::of::<T>());
        Self { before, after }
    }

    /// Builder step: the declaring provider runs after `T`.
    pub fn after<T: 'static>(self) -> Self {
        let Self { before, mut after } = self;
        after.push(TypeId::of::<T>());
        Self { before, after }
    }

    /// Types declared via [`ProviderOrder::before`], in declaration order.
    pub fn before_types(&self) -> &[TypeId] {
        self.before.as_slice()
    }

    /// Types declared via [`ProviderOrder::after`], in declaration order.
    pub fn after_types(&self) -> &[TypeId] {
        self.after.as_slice()
    }
}
247
/// Reload hook for the shared state type itself.
///
/// `Registry::reload_all` awaits this first; if it returns an error, no
/// provider is reloaded and the error is propagated to the caller.
#[async_trait]
pub trait ReloadState: Send + Sync + Sized + 'static {
    /// Refresh the shared state.
    async fn reload(&self) -> Result<()>;
}
252
/// Anything that can hot-reload itself when config changes.
///
/// `reload()` is the same shape as `Provider::boot()` — re-read the
/// on-disk config (use `tokio::fs`, never `std::fs` in this async path)
/// and rebuild the runtime snapshot, publishing it through an
/// `ArcSwap` so in-flight requests/connections see the swap atomically.
/// Reload must NOT change which providers are registered; it only
/// refreshes state of an already-registered provider.
#[async_trait]
pub trait Reloadable<S>: Send + Sync + 'static {
    /// Optional reload priority.
    ///
    /// Lower values run earlier among otherwise-ready providers. The
    /// registry only consults this value when `Provider::boot_priority()`
    /// returns `None`; if both are `None`, `priority::NORMAL` is used.
    /// Prefer `Provider::order()` for real dependency relationships.
    fn priority(&self) -> Option<u8> {
        None
    }

    /// Perform the reload using the current shared state.
    ///
    /// The registry awaits the returned future before moving on, so a
    /// provider is considered reloaded once this resolves to `Ok(())`.
    async fn reload(
        &self,
        state: &S,
    ) -> Result<()>;
}
279
/// Capability trait for providers that produce a long-running runtime task.
///
/// `run()` is the ONLY place in the lifecycle for long-running work
/// (accept loops, listeners, periodic tickers). It must NOT appear in
/// `register()` or `Provider::boot()`.
///
/// Config-driven gating: if the provider is disabled at runtime (e.g.
/// an `enabled: false` config flag, or a single-instance service whose
/// pinned `worker_id` doesn't match this worker), this method MUST
/// short-circuit and return `Ok(())` immediately instead of starting the
/// long task. The provider stays registered for downstream capability
/// lookups; it just doesn't run on this process.
#[async_trait]
pub trait Runnable<S>: Send + Sync + 'static {
    /// Run the long-lived provider task spawned by the bootstrap/supervisor layer.
    ///
    /// NOTICE (convention):
    /// If this future returns `Err`, the implementation should log the
    /// contextual failure details itself (provider/task specific metadata).
    ///
    /// Reason:
    /// - Runtime layer handles lifecycle/control-flow only.
    /// - Runtime cannot reliably attach provider-specific business context.
    /// - Non-critical runnable errors are not centrally logged to avoid
    ///   duplicate/no-context error lines.
    ///
    /// `Registry::run_all` attributes the returned error to the provider:
    /// an `Error::Other` becomes `Error::Run`, and an `Error::Recoverable`
    /// with an empty name gets the provider name filled in.
    async fn run(
        self: Arc<Self>,
        state: S,
    ) -> Result<()>;
}
310
/// Any service that can be registered in the DI registry.
///
/// # Lifecycle convention
///
/// Each provider lives in four explicit phases. Mixing work across
/// phase boundaries is the most common bug source — keep them strict.
///
/// 1. **`register()` (free fn, outside the trait)** — synchronous, no
///    async, called once during bootstrap. Constructs the provider in
///    a placeholder/empty state and inserts it into the registry.
///
///    Allowed:
///    * Read state-level inputs (`state.run_mode()`, `state.config_dir()`)
///      to choose what to register.
///    * Read on-disk config synchronously *only* if the answer decides
///      whether to register the provider at all (e.g. feature toggles,
///      worker pinning). Use `std::fs` here — register is sync.
///
///    Forbidden:
///    * Resolving other providers from the registry (they may not exist
///      yet; ordering is settled by `Provider::order()` and coarse
///      priority, not by register order).
///    * Async I/O.
///    * Spawning tasks.
///    * Building the operational snapshot (that's `boot()`).
///
/// 2. **`boot()`** — async, called after every `register()` ran, in
///    lifecycle order. This is where the provider becomes usable.
///
///    Allowed / expected:
///    * Resolve dependencies from the registry — by now every other
///      `register()` has run.
///    * Async I/O — `tokio::fs` for config, network calls, etc. Never
///      `std::fs` (it blocks the runtime).
///    * Build the runtime snapshot and publish it via `ArcSwap` /
///      `ArcSwapOption` so concurrent readers see atomic swaps.
///    * Honor disabled-state from config: leave the snapshot empty and
///      return `Ok(())` rather than failing.
///
///    Forbidden:
///    * Spawning long-running tasks. Boot must return when state is
///      ready; the long task lives in `Runnable::run()`.
///
/// 3. **`Runnable::run()`** — see that trait. The only place for
///    long-lived loops; honors disabled-state by returning `Ok(())`
///    immediately.
///
/// 4. **`shutdown()`** — async best-effort cleanup after the shutdown
///    signal has fired and before the runtime aborts any remaining
///    runnable tasks. Use this for process-owned resources that must
///    not leak into the next graceful boot. Default no-op.
///
/// Reload (`Reloadable::reload()`) follows the same shape as `boot()`.
#[async_trait]
pub trait Provider<S>: Any + Send + Sync + 'static {
    /// Human-readable label for logs/diagnostics.
    ///
    /// `Registry::reload_one` looks providers up by this label and
    /// `Registry::run_all` uses it as a sort tie-breaker, so overriding the
    /// shared default `"provider"` with a unique value is advisable.
    fn name(&self) -> &'static str {
        "provider"
    }

    /// Optional boot priority. Lower values run earlier among otherwise-ready
    /// providers. When this returns `None` the registry falls back to
    /// `Reloadable::priority()` (if the provider is reloadable) and then to
    /// `priority::NORMAL`. Prefer `Provider::order()` for dependency
    /// relationships; priority is only a coarse tie-breaker.
    fn boot_priority(&self) -> Option<u8> {
        None
    }

    /// Optional runtime task start priority. Lower values run earlier.
    /// `None` means `priority::NORMAL`.
    fn run_priority(&self) -> Option<u8> {
        None
    }

    /// Optional type-based boot/reload ordering hints.
    ///
    /// The registry builds one ordered lifecycle plan and uses it for boot,
    /// validate, shutdown, and reload. Reload skips providers that are not
    /// `Reloadable`, but dependency relationships remain the same: reload is
    /// a boot emulation on a live process. Edges that name a type which is
    /// not registered are ignored.
    fn order(&self) -> ProviderOrder {
        ProviderOrder::default()
    }

    /// Bootstrap-time async initialization. See the trait-level lifecycle
    /// convention for what belongs here vs in `register()` / `run()`.
    /// Default no-op so providers that only need `register()` insertion
    /// don't have to implement this.
    async fn boot(
        &self,
        _state: &S,
    ) -> Result<()> {
        Ok(())
    }

    /// Graceful-shutdown cleanup hook.
    ///
    /// This is not a replacement for `Drop`: it is the lifecycle point
    /// for externally named resources whose stale presence can break the
    /// next boot, such as shm segments or lock files. Implementations
    /// should be idempotent because shutdown paths may be re-entered.
    ///
    /// `Registry::shutdown_all` calls this in reverse lifecycle order and
    /// only logs a warning when it fails.
    async fn shutdown(
        &self,
        _state: &S,
    ) -> Result<()> {
        Ok(())
    }

    /// Synchronous preflight validation. Runs in the config-check / startup
    /// validation phase before any `boot()` to fail fast on bad config
    /// (missing files, conflicting settings) without touching the registry.
    fn validate(
        &self,
        _state: &S,
    ) -> Result<()> {
        Ok(())
    }

    /// Downcast hook for typed resolve APIs.
    ///
    /// Because of the `Self: Sized` bound this method is only callable on a
    /// concrete provider type, not through `dyn Provider<S>`.
    fn as_any(&self) -> &dyn Any
    where
        Self: Sized,
    {
        self
    }

    /// Optional capability hook: expose this provider as a `Reloadable`.
    ///
    /// With the default `None`, `Registry::reload_all` skips the provider
    /// and `Registry::reload_one` rejects it as not reloadable.
    fn as_reloadable(&self) -> Option<&dyn Reloadable<S>> {
        None
    }

    /// Optional capability hook: expose this provider as a `Runnable`.
    ///
    /// `Registry::run_all` spawns a task only for providers that return
    /// `Some` here.
    fn as_runnable(self: Arc<Self>) -> Option<Arc<dyn Runnable<S>>> {
        None
    }
}
446
/// Type-erased provider registry used for service discovery and DI-style lookup.
/// Registration happens during bootstrap and runtime access is read-only via typed resolves.
/// We keep the underlying maps behind `RwLock<HashMap<..>>` so registration stays simple while
/// lookup only holds a short-lived read lock long enough to clone the stored `Arc`.
pub struct Registry<S> {
    /// Providers keyed by concrete type, stored as `dyn Provider<S>` for
    /// lifecycle calls.
    providers: RwLock<HashMap<TypeId, Arc<dyn Provider<S>>>>,
    /// The same allocations stored as `dyn Any`, used by `resolve` to
    /// downcast back to the concrete type.
    by_type: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
    /// Concrete types in the order they were inserted.
    registration_order: RwLock<Vec<TypeId>>,
    /// Cached lifecycle plan; `None` until built and reset to `None` by
    /// every successful `insert`.
    lifecycle_order: RwLock<Option<Vec<TypeId>>>,
}
457
458impl<S: 'static> Registry<S> {
459    /// Create the service with an empty registry. You can register later.
460    pub fn new() -> Self {
461        Self {
462            providers: RwLock::new(HashMap::new()),
463            by_type: RwLock::new(HashMap::new()),
464            registration_order: RwLock::new(Vec::new()),
465            lifecycle_order: RwLock::new(None),
466        }
467    }
468
469    /// Register a provider into the registry.
470    ///
471    /// This accepts `Arc<T>` where `T: Provider`. The service is stored as a
472    /// type-erased `Arc<dyn Provider>` but continues to point to the same underlying
473    /// allocation (no new allocation is created).
474    ///
475    /// If another service with the same concrete type is already registered,
476    /// the new registration is skipped and a warning is logged.
477    ///
478    /// Returns `&Self` to allow fluent chaining:
479    ///
480    /// ```ignore
481    /// registry
482    ///     .insert(dns.clone())
483    ///     .insert(ipc.clone());
484    /// ```
485    pub fn insert<C>(
486        &self,
487        item: Arc<C>,
488    ) -> &Self
489    where
490        C: Provider<S> + 'static,
491    {
492        let type_id = TypeId::of::<C>();
493        let any: Arc<dyn Any + Send + Sync> = item.clone();
494        let mut by_type = self.by_type.write().expect("registry by_type lock poisoned");
495        if by_type.contains_key(&type_id) {
496            warn!(
497                "⚠️ duplicate provider type '{}' — skipping registration",
498                std::any::type_name::<C>()
499            );
500            return self;
501        }
502        by_type.insert(type_id, any);
503        drop(by_type);
504
505        let it: Arc<dyn Provider<S>> = item;
506        self.providers.write().expect("registry providers lock poisoned").insert(type_id, it);
507        self.registration_order.write().expect("registry order lock poisoned").push(type_id);
508        *self.lifecycle_order.write().expect("registry lifecycle order lock poisoned") = None;
509        self
510    }
511
512    /// Execute a closure with a concrete typed reference `&T` if the service is registered.
513    pub fn with_typed<T, R>(
514        &self,
515        f: impl FnOnce(&T) -> R,
516    ) -> Option<R>
517    where
518        T: Provider<S> + 'static,
519    {
520        let typed = self.resolve::<T>()?;
521        Some(f(typed.as_ref()))
522    }
523
524    /// Resolve a concrete service as an owned `Arc<T>` handle.
525    ///
526    /// This is the DI-style, high-level API: it returns a typed `Arc<T>` that
527    /// points to the same underlying allocation as the internally registered
528    /// provider (no new `Arc` allocation). The returned `Arc` is obtained by
529    /// downcasting from a type-indexed map (`TypeId`).
530    ///
531    /// Returns `None` if the type is not registered.
532    pub fn resolve<T>(&self) -> Option<Arc<T>>
533    where
534        T: Provider<S> + 'static,
535    {
536        let any = self
537            .by_type
538            .read()
539            .expect("registry by_type lock poisoned")
540            .get(&TypeId::of::<T>())?
541            .clone();
542        Arc::downcast::<T>(any).ok()
543    }
544
545    /// Return a snapshot of registered providers.
546    #[allow(unused)]
547    pub fn providers(&self) -> Vec<Arc<dyn Provider<S>>> {
548        self.providers.read().expect("registry providers lock poisoned").values().cloned().collect()
549    }
550
551    fn provider_entries_snapshot(&self) -> Vec<ProviderEntry<S>> {
552        let providers = self.providers.read().expect("registry providers lock poisoned");
553        self.registration_order
554            .read()
555            .expect("registry order lock poisoned")
556            .iter()
557            .enumerate()
558            .filter_map(|(index, type_id)| {
559                providers.get(type_id).cloned().map(|provider| ProviderEntry {
560                    type_id: *type_id,
561                    index,
562                    provider,
563                })
564            })
565            .collect()
566    }
567
568    /// Return the cached lifecycle plan, building it once if needed.
569    ///
570    /// The plan is invalidated on `insert()`. Normal lifecycle phases reuse
571    /// the same known list, so reload is a boot emulation over the same
572    /// provider order instead of a second ordering universe.
573    fn lifecycle_plan(&self) -> Result<Vec<Arc<dyn Provider<S>>>> {
574        if let Some(type_ids) = self
575            .lifecycle_order
576            .read()
577            .expect("registry lifecycle order lock poisoned")
578            .as_ref()
579            .cloned()
580        {
581            return Ok(self.providers_from_type_ids(&type_ids));
582        }
583
584        let ordered = order_provider_entries(self.provider_entries_snapshot())?;
585        let type_ids = ordered.iter().map(|entry| entry.type_id).collect::<Vec<_>>();
586        let providers = ordered.iter().map(|entry| entry.provider.clone()).collect::<Vec<_>>();
587        #[cfg(debug_assertions)]
588        tracing::debug!(
589            providers = ?providers.iter().map(|provider| provider.name()).collect::<Vec<_>>(),
590            "provider lifecycle order"
591        );
592        *self.lifecycle_order.write().expect("registry lifecycle order lock poisoned") =
593            Some(type_ids);
594        Ok(providers)
595    }
596
597    fn providers_from_type_ids(
598        &self,
599        type_ids: &[TypeId],
600    ) -> Vec<Arc<dyn Provider<S>>> {
601        let providers = self.providers.read().expect("registry providers lock poisoned");
602        type_ids.iter().filter_map(|type_id| providers.get(type_id).cloned()).collect()
603    }
604
605    /// Return the list of provider display names (for diagnostics only).
606    #[allow(unused)]
607    pub fn list_names(&self) -> Vec<&'static str> {
608        self.providers().iter().map(|c| c.name()).collect()
609    }
610
611    /// Return provider display names in lifecycle order.
612    ///
613    /// This is useful for diagnostics and startup logging before running
614    /// `boot_all()`.
615    pub fn lifecycle_names(&self) -> Result<Vec<&'static str>> {
616        Ok(self.lifecycle_plan()?.iter().map(|provider| provider.name()).collect())
617    }
618
619    /// Spawn all runnable providers into the given JoinSet.
620    ///
621    /// Returns the number of tasks spawned.
622    pub fn run_all(
623        &self,
624        state: S,
625        join_set: &mut tokio::task::JoinSet<Result<()>>,
626    ) -> usize
627    where
628        S: Clone + Send + 'static,
629    {
630        let mut spawned = 0usize;
631        let mut providers = self.providers();
632        providers.sort_by_key(|provider| {
633            (provider.run_priority().unwrap_or(priority::NORMAL), provider.name())
634        });
635
636        for provider in providers {
637            let Some(runnable) = provider.clone().as_runnable() else { continue };
638
639            let name = provider.name();
640            let state = state.clone();
641            join_set.spawn(
642                async move { runnable.run(state).await.map_err(|e| e.into_run(name)) }
643                    .instrument(tracing::debug_span!("provider", provider = %name)),
644            );
645            spawned += 1;
646        }
647
648        spawned
649    }
650
651    /// Run `validate` hook for all registered providers.
652    pub fn validate_all(
653        &self,
654        state: &S,
655    ) -> Result<()> {
656        for provider in self.lifecycle_plan()? {
657            let name = provider.name();
658            provider.validate(state).map_err(|e| e.into_validate(name))?;
659        }
660        Ok(())
661    }
662
663    pub async fn boot_all(
664        &self,
665        state: &S,
666    ) -> Result<()> {
667        for provider in self.lifecycle_plan()? {
668            let name = provider.name();
669            // debug!("🚀 booting provider '{}'", name);
670            if let Err(e) = provider.boot(state).await {
671                error!("❌ boot provider '{}' failed: {}", name, e);
672                return Err(e.into_boot(name));
673            }
674            // debug!("✅ provider '{}' booted", name);
675        }
676        Ok(())
677    }
678
679    pub async fn shutdown_all(
680        &self,
681        state: &S,
682    ) -> Result<()> {
683        let mut providers = self.lifecycle_plan()?;
684        providers.reverse();
685
686        for provider in providers {
687            let name = provider.name();
688            if let Err(e) = provider.shutdown(state).await {
689                warn!("shutdown of provider '{}' failed: {}", name, e);
690            }
691        }
692        Ok(())
693    }
694
695    pub async fn reload_one(
696        &self,
697        name: &str,
698        state: &S,
699    ) -> Result<()> {
700        let Some(provider) = self.providers().into_iter().find(|provider| provider.name() == name)
701        else {
702            return Err(Error::msg(format!(
703                "reload_by_name: no provider registered with name '{}'",
704                name
705            )));
706        };
707
708        let Some(reloadable) = provider.as_reloadable() else {
709            return Err(Error::msg(format!(
710                "reload_by_name: provider '{}' is not reloadable",
711                name
712            )));
713        };
714
715        info!("♻️  reloading service '{}'", name);
716
717        match reloadable.reload(state).await {
718            Ok(()) => {
719                info!("♻️  {} reloaded", name);
720                Ok(())
721            }
722            Err(e) => {
723                warn!("❌ reload of {} failed: {e}", name);
724                // Resolve the static name from the provider before consuming it.
725                let static_name = provider.name();
726                Err(e.into_reload(static_name))
727            }
728        }
729    }
730}
731
732impl<S> Registry<S>
733where
734    S: ReloadState + 'static,
735{
736    pub async fn reload_all(
737        &self,
738        state: &S,
739    ) -> Result<()> {
740        state.reload().await?;
741
742        info!("✅ state reloaded");
743
744        for provider in self.lifecycle_plan()? {
745            let name = provider.name();
746            if let Some(reloadable) = provider.as_reloadable() {
747                if let Err(e) = reloadable.reload(state).await {
748                    warn!("❌ reload of {} failed: {e}", name);
749                } else {
750                    info!("♻️  {} reloaded", name);
751                }
752            }
753        }
754
755        Ok(())
756    }
757}
758
759struct ProviderEntry<S> {
760    type_id: TypeId,
761    index: usize,
762    provider: Arc<dyn Provider<S>>,
763}
764
765impl<S> Clone for ProviderEntry<S> {
766    fn clone(&self) -> Self {
767        Self { type_id: self.type_id, index: self.index, provider: self.provider.clone() }
768    }
769}
770
771fn order_provider_entries<S: 'static>(
772    entries: Vec<ProviderEntry<S>>
773) -> Result<Vec<ProviderEntry<S>>> {
774    let len = entries.len();
775    let positions: HashMap<TypeId, usize> =
776        entries.iter().enumerate().map(|(idx, entry)| (entry.type_id, idx)).collect();
777    let priorities: Vec<u8> =
778        entries.iter().map(|entry| lifecycle_priority(&entry.provider)).collect();
779    let mut outgoing: Vec<HashSet<usize>> = (0..len).map(|_| HashSet::new()).collect();
780    let mut indegree = vec![0usize; len];
781
782    let mut add_edge = |from: usize, to: usize| {
783        if from != to && outgoing[from].insert(to) {
784            indegree[to] += 1;
785        }
786    };
787
788    for (idx, entry) in entries.iter().enumerate() {
789        let order = entry.provider.order();
790        for target in order.before_types() {
791            if let Some(&target_idx) = positions.get(target) {
792                add_edge(idx, target_idx);
793            }
794        }
795        for target in order.after_types() {
796            if let Some(&target_idx) = positions.get(target) {
797                add_edge(target_idx, idx);
798            }
799        }
800    }
801
802    let mut ready: Vec<usize> = indegree
803        .iter()
804        .enumerate()
805        .filter_map(|(idx, degree)| (*degree == 0).then_some(idx))
806        .collect();
807    let mut ordered = Vec::with_capacity(len);
808
809    while !ready.is_empty() {
810        ready.sort_by_key(|idx| {
811            (priorities[*idx], entries[*idx].index, entries[*idx].provider.name())
812        });
813        let idx = ready.remove(0);
814        ordered.push(idx);
815
816        let next: Vec<_> = outgoing[idx].iter().copied().collect();
817        for target in next {
818            indegree[target] -= 1;
819            if indegree[target] == 0 {
820                ready.push(target);
821            }
822        }
823    }
824
825    if ordered.len() != len {
826        let blocked = indegree
827            .iter()
828            .enumerate()
829            .filter_map(|(idx, degree)| (*degree > 0).then_some(entries[idx].provider.name()))
830            .collect::<Vec<_>>()
831            .join(", ");
832        return Err(Error::msg(format!("provider lifecycle order cycle detected: {blocked}")));
833    }
834
835    Ok(ordered.into_iter().map(|idx| entries[idx].clone()).collect())
836}
837
838fn lifecycle_priority<S: 'static>(provider: &Arc<dyn Provider<S>>) -> u8 {
839    provider
840        .boot_priority()
841        .or_else(|| provider.as_reloadable().and_then(|reloadable| reloadable.priority()))
842        .unwrap_or(priority::NORMAL)
843}
844
/// `Default` is an empty registry, identical to [`Registry::new`].
///
/// Written by hand rather than derived so that no `S: Default` bound is
/// imposed on the state type.
impl<S: 'static> Default for Registry<S> {
    fn default() -> Self {
        Self::new()
    }
}
850
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Shared state that records the order in which providers were validated.
    #[derive(Clone, Default)]
    struct TestState {
        seen: Arc<Mutex<Vec<&'static str>>>,
    }

    impl TestState {
        /// Append `label` to the validation log.
        fn record(
            &self,
            label: &'static str,
        ) {
            self.seen.lock().expect("test log poisoned").push(label);
        }

        /// Copy of the validation log so far.
        fn recorded(&self) -> Vec<&'static str> {
            self.seen.lock().expect("test log poisoned").clone()
        }
    }

    // Dependency chain under test: db <- cache <- api.
    struct DbProvider;
    struct CacheProvider;
    struct ApiProvider;

    #[async_trait]
    impl Provider<TestState> for DbProvider {
        fn name(&self) -> &'static str {
            "db"
        }

        fn validate(
            &self,
            state: &TestState,
        ) -> Result<()> {
            state.record("db");
            Ok(())
        }
    }

    #[async_trait]
    impl Provider<TestState> for CacheProvider {
        fn name(&self) -> &'static str {
            "cache"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<DbProvider>()
        }

        fn validate(
            &self,
            state: &TestState,
        ) -> Result<()> {
            state.record("cache");
            Ok(())
        }
    }

    #[async_trait]
    impl Provider<TestState> for ApiProvider {
        fn name(&self) -> &'static str {
            "api"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CacheProvider>()
        }

        fn validate(
            &self,
            state: &TestState,
        ) -> Result<()> {
            state.record("api");
            Ok(())
        }
    }

    /// Providers registered in the "wrong" order must still be validated in
    /// dependency order.
    #[test]
    fn lifecycle_order_uses_type_dependencies() {
        let registry = Registry::<TestState>::new();
        registry.insert(Arc::new(ApiProvider));
        registry.insert(Arc::new(CacheProvider));
        registry.insert(Arc::new(DbProvider));

        let state = TestState::default();
        registry.validate_all(&state).expect("validation should succeed");

        assert_eq!(state.recorded(), vec!["db", "cache", "api"]);
    }

    // Two providers that each insist on running after the other.
    struct CycleA;
    struct CycleB;

    #[async_trait]
    impl Provider<TestState> for CycleA {
        fn name(&self) -> &'static str {
            "cycle-a"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CycleB>()
        }
    }

    #[async_trait]
    impl Provider<TestState> for CycleB {
        fn name(&self) -> &'static str {
            "cycle-b"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CycleA>()
        }
    }

    /// A dependency cycle must surface as an error instead of a partial plan.
    #[test]
    fn lifecycle_order_rejects_cycles() {
        let registry = Registry::<TestState>::new();
        registry.insert(Arc::new(CycleA));
        registry.insert(Arc::new(CycleB));

        let state = TestState::default();
        let err = registry.validate_all(&state).expect_err("cycle must be rejected");

        let message = err.to_string();
        assert!(message.contains("provider lifecycle order cycle detected"));
    }
}