1use std::any::{Any, TypeId};
2use std::collections::{HashMap, HashSet};
3use std::sync::{Arc, RwLock};
4
5use async_trait::async_trait;
6use tracing::{Instrument, error, info, warn};
7
8pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
27
28pub type Result<T, E = Error> = std::result::Result<T, E>;
29
30#[derive(Debug)]
31pub enum Error {
32 Boot {
33 name: &'static str,
34 source: BoxError,
35 },
36 Validate {
37 name: &'static str,
38 source: BoxError,
39 },
40 Reload {
41 name: &'static str,
42 source: BoxError,
43 },
44 Run {
47 name: &'static str,
48 source: BoxError,
49 },
50 Recoverable {
55 name: &'static str,
56 source: BoxError,
57 },
58 Other(BoxError),
59}
60
61impl std::fmt::Display for Error {
62 fn fmt(
63 &self,
64 f: &mut std::fmt::Formatter<'_>,
65 ) -> std::fmt::Result {
66 match self {
67 Error::Boot { name, source } => {
68 write!(f, "provider '{name}' failed during boot: {source}")
69 }
70 Error::Validate { name, source } => {
71 write!(f, "provider '{name}' failed during validate: {source}")
72 }
73 Error::Reload { name, source } => {
74 write!(f, "reload of '{name}' failed: {source}")
75 }
76 Error::Run { name, source } => {
77 write!(f, "runnable '{name}' failed: {source}")
78 }
79 Error::Recoverable { name, source } => {
80 write!(f, "runnable '{name}' failed (recoverable): {source}")
81 }
82 Error::Other(e) => std::fmt::Display::fmt(e, f),
83 }
84 }
85}
86
87impl<E> From<E> for Error
94where
95 E: std::error::Error + Send + Sync + 'static,
96{
97 fn from(e: E) -> Self {
98 Error::Other(Box::new(e))
99 }
100}
101
102impl Error {
104 pub fn msg(s: impl Into<String>) -> Self {
105 #[derive(Debug)]
106 struct MsgErr(String);
107 impl std::fmt::Display for MsgErr {
108 fn fmt(
109 &self,
110 f: &mut std::fmt::Formatter<'_>,
111 ) -> std::fmt::Result {
112 std::fmt::Display::fmt(&self.0, f)
113 }
114 }
115 impl std::error::Error for MsgErr {}
116 Error::Other(Box::new(MsgErr(s.into())))
117 }
118
119 fn into_boot(
123 self,
124 name: &'static str,
125 ) -> Self {
126 match self {
127 Error::Other(source) => Error::Boot { name, source },
128 other => other,
129 }
130 }
131 fn into_validate(
132 self,
133 name: &'static str,
134 ) -> Self {
135 match self {
136 Error::Other(source) => Error::Validate { name, source },
137 other => other,
138 }
139 }
140 fn into_reload(
144 self,
145 name: &'static str,
146 ) -> Self {
147 match self {
148 Error::Other(source) => Error::Reload { name, source },
149 other => other,
150 }
151 }
152 fn into_run(
153 self,
154 name: &'static str,
155 ) -> Self {
156 match self {
157 Error::Other(source) => Error::Run { name, source },
158 Error::Recoverable { name: "", source } => Error::Recoverable { name, source },
162 other => other,
163 }
164 }
165
166 pub fn recoverable(s: impl Into<String>) -> Self {
171 #[derive(Debug)]
172 struct MsgErr(String);
173 impl std::fmt::Display for MsgErr {
174 fn fmt(
175 &self,
176 f: &mut std::fmt::Formatter<'_>,
177 ) -> std::fmt::Result {
178 std::fmt::Display::fmt(&self.0, f)
179 }
180 }
181 impl std::error::Error for MsgErr {}
182 Error::Recoverable { name: "", source: Box::new(MsgErr(s.into())) }
183 }
184}
185
/// Named priority values for `Provider::boot_priority`, `Provider::run_priority` and
/// `Reloadable::priority`.
///
/// Lower values are ordered earlier. Any `u8` is accepted; these constants are just
/// conventional anchor points.
pub mod priority {
    /// The earliest possible slot.
    #[doc(hidden)]
    pub const FIRST: u8 = u8::MIN;

    /// Ahead of the default slot.
    pub const EARLY: u8 = 50;

    /// The default slot used when no priority is given.
    pub const NORMAL: u8 = 100;

    /// Behind the default slot.
    pub const LATE: u8 = 150;

    /// The latest possible slot.
    #[doc(hidden)]
    pub const LAST: u8 = u8::MAX;
}
212
/// Ordering constraints a provider declares relative to other provider *types*.
///
/// Constraints are keyed by the concrete provider type's `TypeId`. When the registry builds
/// the lifecycle order, a constraint naming a type that is not registered is ignored.
#[derive(Clone, Debug, Default)]
pub struct ProviderOrder {
    /// Types this provider must be ordered before.
    before: Vec<TypeId>,
    /// Types this provider must be ordered after.
    after: Vec<TypeId>,
}

impl ProviderOrder {
    /// Creates an empty set of constraints (same as `ProviderOrder::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Requires this provider to be ordered before providers of type `T`.
    ///
    /// Consumes and returns the builder, so the result must be kept.
    #[must_use = "ProviderOrder is a builder; the returned value carries the new constraint"]
    pub fn before<T: 'static>(mut self) -> Self {
        self.before.push(TypeId::of::<T>());
        self
    }

    /// Requires this provider to be ordered after providers of type `T`.
    ///
    /// Consumes and returns the builder, so the result must be kept.
    #[must_use = "ProviderOrder is a builder; the returned value carries the new constraint"]
    pub fn after<T: 'static>(mut self) -> Self {
        self.after.push(TypeId::of::<T>());
        self
    }

    /// Types this provider must precede, in the order they were added.
    pub fn before_types(&self) -> &[TypeId] {
        &self.before
    }

    /// Types this provider must follow, in the order they were added.
    pub fn after_types(&self) -> &[TypeId] {
        &self.after
    }
}
247
248#[async_trait]
249pub trait ReloadState: Send + Sync + Sized + 'static {
250 async fn reload(&self) -> Result<()>;
251}
252
253#[async_trait]
262pub trait Reloadable<S>: Send + Sync + 'static {
263 fn priority(&self) -> Option<u8> {
269 None
270 }
271
272 async fn reload(
275 &self,
276 state: &S,
277 ) -> Result<()>;
278}
279
280#[async_trait]
293pub trait Runnable<S>: Send + Sync + 'static {
294 async fn run(
306 self: Arc<Self>,
307 state: S,
308 ) -> Result<()>;
309}
310
311#[async_trait]
365pub trait Provider<S>: Any + Send + Sync + 'static {
366 fn name(&self) -> &'static str {
368 "provider"
369 }
370
371 fn boot_priority(&self) -> Option<u8> {
375 None
376 }
377
378 fn run_priority(&self) -> Option<u8> {
381 None
382 }
383
384 fn order(&self) -> ProviderOrder {
391 ProviderOrder::default()
392 }
393
394 async fn boot(
399 &self,
400 _state: &S,
401 ) -> Result<()> {
402 Ok(())
403 }
404
405 async fn shutdown(
412 &self,
413 _state: &S,
414 ) -> Result<()> {
415 Ok(())
416 }
417
418 fn validate(
422 &self,
423 _state: &S,
424 ) -> Result<()> {
425 Ok(())
426 }
427
428 fn as_any(&self) -> &dyn Any
430 where
431 Self: Sized,
432 {
433 self
434 }
435
436 fn as_reloadable(&self) -> Option<&dyn Reloadable<S>> {
438 None
439 }
440
441 fn as_runnable(self: Arc<Self>) -> Option<Arc<dyn Runnable<S>>> {
443 None
444 }
445}
446
447pub struct Registry<S> {
452 providers: RwLock<HashMap<TypeId, Arc<dyn Provider<S>>>>,
453 by_type: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
454 registration_order: RwLock<Vec<TypeId>>,
455 lifecycle_order: RwLock<Option<Vec<TypeId>>>,
456}
457
458impl<S: 'static> Registry<S> {
459 pub fn new() -> Self {
461 Self {
462 providers: RwLock::new(HashMap::new()),
463 by_type: RwLock::new(HashMap::new()),
464 registration_order: RwLock::new(Vec::new()),
465 lifecycle_order: RwLock::new(None),
466 }
467 }
468
469 pub fn insert<C>(
486 &self,
487 item: Arc<C>,
488 ) -> &Self
489 where
490 C: Provider<S> + 'static,
491 {
492 let type_id = TypeId::of::<C>();
493 let any: Arc<dyn Any + Send + Sync> = item.clone();
494 let mut by_type = self.by_type.write().expect("registry by_type lock poisoned");
495 if by_type.contains_key(&type_id) {
496 warn!(
497 "⚠️ duplicate provider type '{}' — skipping registration",
498 std::any::type_name::<C>()
499 );
500 return self;
501 }
502 by_type.insert(type_id, any);
503 drop(by_type);
504
505 let it: Arc<dyn Provider<S>> = item;
506 self.providers.write().expect("registry providers lock poisoned").insert(type_id, it);
507 self.registration_order.write().expect("registry order lock poisoned").push(type_id);
508 *self.lifecycle_order.write().expect("registry lifecycle order lock poisoned") = None;
509 self
510 }
511
512 pub fn with_typed<T, R>(
514 &self,
515 f: impl FnOnce(&T) -> R,
516 ) -> Option<R>
517 where
518 T: Provider<S> + 'static,
519 {
520 let typed = self.resolve::<T>()?;
521 Some(f(typed.as_ref()))
522 }
523
524 pub fn resolve<T>(&self) -> Option<Arc<T>>
533 where
534 T: Provider<S> + 'static,
535 {
536 let any = self
537 .by_type
538 .read()
539 .expect("registry by_type lock poisoned")
540 .get(&TypeId::of::<T>())?
541 .clone();
542 Arc::downcast::<T>(any).ok()
543 }
544
545 #[allow(unused)]
547 pub fn providers(&self) -> Vec<Arc<dyn Provider<S>>> {
548 self.providers.read().expect("registry providers lock poisoned").values().cloned().collect()
549 }
550
551 fn provider_entries_snapshot(&self) -> Vec<ProviderEntry<S>> {
552 let providers = self.providers.read().expect("registry providers lock poisoned");
553 self.registration_order
554 .read()
555 .expect("registry order lock poisoned")
556 .iter()
557 .enumerate()
558 .filter_map(|(index, type_id)| {
559 providers.get(type_id).cloned().map(|provider| ProviderEntry {
560 type_id: *type_id,
561 index,
562 provider,
563 })
564 })
565 .collect()
566 }
567
568 fn lifecycle_plan(&self) -> Result<Vec<Arc<dyn Provider<S>>>> {
574 if let Some(type_ids) = self
575 .lifecycle_order
576 .read()
577 .expect("registry lifecycle order lock poisoned")
578 .as_ref()
579 .cloned()
580 {
581 return Ok(self.providers_from_type_ids(&type_ids));
582 }
583
584 let ordered = order_provider_entries(self.provider_entries_snapshot())?;
585 let type_ids = ordered.iter().map(|entry| entry.type_id).collect::<Vec<_>>();
586 let providers = ordered.iter().map(|entry| entry.provider.clone()).collect::<Vec<_>>();
587 #[cfg(debug_assertions)]
588 tracing::debug!(
589 providers = ?providers.iter().map(|provider| provider.name()).collect::<Vec<_>>(),
590 "provider lifecycle order"
591 );
592 *self.lifecycle_order.write().expect("registry lifecycle order lock poisoned") =
593 Some(type_ids);
594 Ok(providers)
595 }
596
597 fn providers_from_type_ids(
598 &self,
599 type_ids: &[TypeId],
600 ) -> Vec<Arc<dyn Provider<S>>> {
601 let providers = self.providers.read().expect("registry providers lock poisoned");
602 type_ids.iter().filter_map(|type_id| providers.get(type_id).cloned()).collect()
603 }
604
605 #[allow(unused)]
607 pub fn list_names(&self) -> Vec<&'static str> {
608 self.providers().iter().map(|c| c.name()).collect()
609 }
610
611 pub fn lifecycle_names(&self) -> Result<Vec<&'static str>> {
616 Ok(self.lifecycle_plan()?.iter().map(|provider| provider.name()).collect())
617 }
618
619 pub fn run_all(
623 &self,
624 state: S,
625 join_set: &mut tokio::task::JoinSet<Result<()>>,
626 ) -> usize
627 where
628 S: Clone + Send + 'static,
629 {
630 let mut spawned = 0usize;
631 let mut providers = self.providers();
632 providers.sort_by_key(|provider| {
633 (provider.run_priority().unwrap_or(priority::NORMAL), provider.name())
634 });
635
636 for provider in providers {
637 let Some(runnable) = provider.clone().as_runnable() else { continue };
638
639 let name = provider.name();
640 let state = state.clone();
641 join_set.spawn(
642 async move { runnable.run(state).await.map_err(|e| e.into_run(name)) }
643 .instrument(tracing::debug_span!("provider", provider = %name)),
644 );
645 spawned += 1;
646 }
647
648 spawned
649 }
650
651 pub fn validate_all(
653 &self,
654 state: &S,
655 ) -> Result<()> {
656 for provider in self.lifecycle_plan()? {
657 let name = provider.name();
658 provider.validate(state).map_err(|e| e.into_validate(name))?;
659 }
660 Ok(())
661 }
662
663 pub async fn boot_all(
664 &self,
665 state: &S,
666 ) -> Result<()> {
667 for provider in self.lifecycle_plan()? {
668 let name = provider.name();
669 if let Err(e) = provider.boot(state).await {
671 error!("❌ boot provider '{}' failed: {}", name, e);
672 return Err(e.into_boot(name));
673 }
674 }
676 Ok(())
677 }
678
679 pub async fn shutdown_all(
680 &self,
681 state: &S,
682 ) -> Result<()> {
683 let mut providers = self.lifecycle_plan()?;
684 providers.reverse();
685
686 for provider in providers {
687 let name = provider.name();
688 if let Err(e) = provider.shutdown(state).await {
689 warn!("shutdown of provider '{}' failed: {}", name, e);
690 }
691 }
692 Ok(())
693 }
694
695 pub async fn reload_one(
696 &self,
697 name: &str,
698 state: &S,
699 ) -> Result<()> {
700 let Some(provider) = self.providers().into_iter().find(|provider| provider.name() == name)
701 else {
702 return Err(Error::msg(format!(
703 "reload_by_name: no provider registered with name '{}'",
704 name
705 )));
706 };
707
708 let Some(reloadable) = provider.as_reloadable() else {
709 return Err(Error::msg(format!(
710 "reload_by_name: provider '{}' is not reloadable",
711 name
712 )));
713 };
714
715 info!("♻️ reloading service '{}'", name);
716
717 match reloadable.reload(state).await {
718 Ok(()) => {
719 info!("♻️ {} reloaded", name);
720 Ok(())
721 }
722 Err(e) => {
723 warn!("❌ reload of {} failed: {e}", name);
724 let static_name = provider.name();
726 Err(e.into_reload(static_name))
727 }
728 }
729 }
730}
731
732impl<S> Registry<S>
733where
734 S: ReloadState + 'static,
735{
736 pub async fn reload_all(
737 &self,
738 state: &S,
739 ) -> Result<()> {
740 state.reload().await?;
741
742 info!("✅ state reloaded");
743
744 for provider in self.lifecycle_plan()? {
745 let name = provider.name();
746 if let Some(reloadable) = provider.as_reloadable() {
747 if let Err(e) = reloadable.reload(state).await {
748 warn!("❌ reload of {} failed: {e}", name);
749 } else {
750 info!("♻️ {} reloaded", name);
751 }
752 }
753 }
754
755 Ok(())
756 }
757}
758
759struct ProviderEntry<S> {
760 type_id: TypeId,
761 index: usize,
762 provider: Arc<dyn Provider<S>>,
763}
764
765impl<S> Clone for ProviderEntry<S> {
766 fn clone(&self) -> Self {
767 Self { type_id: self.type_id, index: self.index, provider: self.provider.clone() }
768 }
769}
770
771fn order_provider_entries<S: 'static>(
772 entries: Vec<ProviderEntry<S>>
773) -> Result<Vec<ProviderEntry<S>>> {
774 let len = entries.len();
775 let positions: HashMap<TypeId, usize> =
776 entries.iter().enumerate().map(|(idx, entry)| (entry.type_id, idx)).collect();
777 let priorities: Vec<u8> =
778 entries.iter().map(|entry| lifecycle_priority(&entry.provider)).collect();
779 let mut outgoing: Vec<HashSet<usize>> = (0..len).map(|_| HashSet::new()).collect();
780 let mut indegree = vec![0usize; len];
781
782 let mut add_edge = |from: usize, to: usize| {
783 if from != to && outgoing[from].insert(to) {
784 indegree[to] += 1;
785 }
786 };
787
788 for (idx, entry) in entries.iter().enumerate() {
789 let order = entry.provider.order();
790 for target in order.before_types() {
791 if let Some(&target_idx) = positions.get(target) {
792 add_edge(idx, target_idx);
793 }
794 }
795 for target in order.after_types() {
796 if let Some(&target_idx) = positions.get(target) {
797 add_edge(target_idx, idx);
798 }
799 }
800 }
801
802 let mut ready: Vec<usize> = indegree
803 .iter()
804 .enumerate()
805 .filter_map(|(idx, degree)| (*degree == 0).then_some(idx))
806 .collect();
807 let mut ordered = Vec::with_capacity(len);
808
809 while !ready.is_empty() {
810 ready.sort_by_key(|idx| {
811 (priorities[*idx], entries[*idx].index, entries[*idx].provider.name())
812 });
813 let idx = ready.remove(0);
814 ordered.push(idx);
815
816 let next: Vec<_> = outgoing[idx].iter().copied().collect();
817 for target in next {
818 indegree[target] -= 1;
819 if indegree[target] == 0 {
820 ready.push(target);
821 }
822 }
823 }
824
825 if ordered.len() != len {
826 let blocked = indegree
827 .iter()
828 .enumerate()
829 .filter_map(|(idx, degree)| (*degree > 0).then_some(entries[idx].provider.name()))
830 .collect::<Vec<_>>()
831 .join(", ");
832 return Err(Error::msg(format!("provider lifecycle order cycle detected: {blocked}")));
833 }
834
835 Ok(ordered.into_iter().map(|idx| entries[idx].clone()).collect())
836}
837
838fn lifecycle_priority<S: 'static>(provider: &Arc<dyn Provider<S>>) -> u8 {
839 provider
840 .boot_priority()
841 .or_else(|| provider.as_reloadable().and_then(|reloadable| reloadable.priority()))
842 .unwrap_or(priority::NORMAL)
843}
844
845impl<S: 'static> Default for Registry<S> {
846 fn default() -> Self {
847 Self::new()
848 }
849}
850
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// State that records which providers were validated, in call order.
    #[derive(Clone, Default)]
    struct TestState {
        seen: Arc<Mutex<Vec<&'static str>>>,
    }

    struct DbProvider;
    struct CacheProvider;
    struct ApiProvider;

    #[async_trait]
    impl Provider<TestState> for DbProvider {
        fn name(&self) -> &'static str {
            "db"
        }

        fn validate(&self, state: &TestState) -> Result<()> {
            state.seen.lock().expect("test log poisoned").push("db");
            Ok(())
        }
    }

    #[async_trait]
    impl Provider<TestState> for CacheProvider {
        fn name(&self) -> &'static str {
            "cache"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<DbProvider>()
        }

        fn validate(&self, state: &TestState) -> Result<()> {
            state.seen.lock().expect("test log poisoned").push("cache");
            Ok(())
        }
    }

    #[async_trait]
    impl Provider<TestState> for ApiProvider {
        fn name(&self) -> &'static str {
            "api"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CacheProvider>()
        }

        fn validate(&self, state: &TestState) -> Result<()> {
            state.seen.lock().expect("test log poisoned").push("api");
            Ok(())
        }
    }

    #[test]
    fn lifecycle_order_uses_type_dependencies() {
        let state = TestState::default();
        let registry = Registry::<TestState>::new();

        registry
            .insert(Arc::new(ApiProvider))
            .insert(Arc::new(CacheProvider))
            .insert(Arc::new(DbProvider));

        registry.validate_all(&state).expect("validation should succeed");

        let seen = state.seen.lock().expect("test log poisoned").clone();
        assert_eq!(seen, vec!["db", "cache", "api"]);
    }

    #[test]
    fn lifecycle_order_is_recomputed_after_insert() {
        let registry = Registry::<TestState>::new();

        // `CacheProvider`'s constraint on the not-yet-registered `DbProvider` is ignored.
        registry.insert(Arc::new(ApiProvider)).insert(Arc::new(CacheProvider));
        assert_eq!(registry.lifecycle_names().expect("no cycle"), vec!["cache", "api"]);

        // Registering a new provider must invalidate the cached order.
        registry.insert(Arc::new(DbProvider));
        assert_eq!(registry.lifecycle_names().expect("no cycle"), vec!["db", "cache", "api"]);
    }

    #[test]
    fn duplicate_provider_types_are_ignored() {
        let registry = Registry::<TestState>::new();

        registry.insert(Arc::new(DbProvider)).insert(Arc::new(DbProvider));

        assert_eq!(registry.list_names(), vec!["db"]);
        assert!(registry.resolve::<DbProvider>().is_some());
        assert!(registry.resolve::<CacheProvider>().is_none());
    }

    struct EarlyProvider;

    #[async_trait]
    impl Provider<TestState> for EarlyProvider {
        fn name(&self) -> &'static str {
            "early"
        }

        fn boot_priority(&self) -> Option<u8> {
            Some(priority::EARLY)
        }
    }

    #[test]
    fn lifecycle_order_uses_priority_when_unconstrained() {
        let registry = Registry::<TestState>::new();

        registry.insert(Arc::new(DbProvider)).insert(Arc::new(EarlyProvider));

        assert_eq!(registry.lifecycle_names().expect("no cycle"), vec!["early", "db"]);
    }

    struct CycleA;
    struct CycleB;

    #[async_trait]
    impl Provider<TestState> for CycleA {
        fn name(&self) -> &'static str {
            "cycle-a"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CycleB>()
        }
    }

    #[async_trait]
    impl Provider<TestState> for CycleB {
        fn name(&self) -> &'static str {
            "cycle-b"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CycleA>()
        }
    }

    #[test]
    fn lifecycle_order_rejects_cycles() {
        let state = TestState::default();
        let registry = Registry::<TestState>::new();

        registry.insert(Arc::new(CycleA)).insert(Arc::new(CycleB));

        let err = registry.validate_all(&state).expect_err("cycle must be rejected");
        assert!(err.to_string().contains("provider lifecycle order cycle detected"));
    }
}