1use crate::entry::IServiceResolver;
2use crate::provider::ServiceProvider;
3use std::any::Any;
4use std::sync::Arc;
5
6pub struct ServiceProviderWrapper {
7 child: Arc<ServiceProvider>,
8 root: Arc<ServiceProvider>,
9}
10
11impl ServiceProviderWrapper {
12 pub fn new(child: Arc<ServiceProvider>, root: Arc<ServiceProvider>) -> Arc<Self> {
13 Arc::new(Self { child, root })
14 }
15 pub fn child(&self) -> &Arc<ServiceProvider> {
16 &self.child
17 }
18 pub fn root(&self) -> &Arc<ServiceProvider> {
19 &self.root
20 }
21
22 pub fn get<T: ?Sized + Send + Sync + 'static>(&self) -> Arc<T> {
23 self.try_get::<T>()
24 .unwrap_or_else(|| panic!("service not registered: {}", std::any::type_name::<T>()))
25 }
26 pub fn get_optional<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
27 self.try_get::<T>()
28 }
29 pub fn get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Arc<T> {
30 self.try_get_keyed::<T>(key).unwrap_or_else(|| {
31 panic!(
32 "keyed service not found: {}:{}",
33 std::any::type_name::<T>(),
34 key
35 )
36 })
37 }
38 pub fn get_all<T: ?Sized + Send + Sync + 'static>(&self) -> Vec<Arc<T>> {
39 let mut r = self.child.get_all::<T>();
40 r.extend(self.root.get_all::<T>());
41 r
42 }
43 pub fn get_named<T: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<T>> {
44 self.child
45 .get_named::<T>(name)
46 .or_else(|| self.root.get_named::<T>(name))
47 }
48 pub fn get_named_any(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
49 self.child
50 .get_named_any(name)
51 .or_else(|| self.root.get_named_any(name))
52 }
53
54 fn try_get<T: ?Sized + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
55 self.child
56 .get_optional::<T>()
57 .or_else(|| self.root.get_optional::<T>())
58 }
59
60 fn try_get_keyed<T: ?Sized + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
61 IServiceResolver::get_keyed_any(self.child.as_ref(), std::any::type_name::<T>(), key)
63 .or_else(|| {
64 IServiceResolver::get_keyed_any(self.root.as_ref(), std::any::type_name::<T>(), key)
65 })
66 .and_then(|arc| crate::provider::ServiceProvider::extract(arc))
67 }
68}
69
70impl IServiceResolver for ServiceProviderWrapper {
71 fn get_any(&self, key: &str) -> Option<Arc<dyn Any + Send + Sync>> {
72 IServiceResolver::get_any(self.child.as_ref(), key)
73 .or_else(|| IServiceResolver::get_any(self.root.as_ref(), key))
74 }
75 fn get_keyed_any(&self, key: &str, variant: &str) -> Option<Arc<dyn Any + Send + Sync>> {
76 IServiceResolver::get_keyed_any(self.child.as_ref(), key, variant)
77 .or_else(|| IServiceResolver::get_keyed_any(self.root.as_ref(), key, variant))
78 }
79}
80
81impl ServiceProviderWrapper {
82 pub fn rdi_register_named(&self, name: &str, service: Arc<dyn Any + Send + Sync>) {
84 self.child.rdi_register_named(name, service);
85 }
86
87 pub fn rdi_remove_named(&self, name: &str) {
89 self.child.rdi_remove_named(name);
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collection::ServiceCollection;

    /// Service registered only in a child provider.
    #[derive(Debug, PartialEq)]
    struct PO;

    /// Service registered only in a root provider.
    #[derive(Debug, PartialEq)]
    struct RO {
        v: i32,
    }

    /// Service registered in both providers, with distinguishable payloads.
    #[derive(Debug, PartialEq)]
    struct B {
        s: String,
    }

    #[test]
    fn child_prio() {
        let r = Arc::new(
            ServiceCollection::new()
                .singleton(|_| Arc::new(B { s: "root".into() }))
                .build()
                .unwrap(),
        );
        let c = ServiceCollection::new()
            .singleton(|_| Arc::new(B { s: "child".into() }))
            .build()
            .unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r);
        assert_eq!(w.get::<B>().s, "child");
    }

    #[test]
    fn root_fallback() {
        let r = Arc::new(
            ServiceCollection::new()
                .singleton(|_| Arc::new(RO { v: 42 }))
                .build()
                .unwrap(),
        );
        let c = ServiceCollection::new().build().unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r);
        assert_eq!(w.get::<RO>().v, 42);
    }

    #[test]
    fn child_invisible() {
        let r = Arc::new(ServiceCollection::new().build().unwrap());
        let c = ServiceCollection::new()
            .singleton(|_| Arc::new(PO))
            .build()
            .unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r.clone());
        // The wrapper must resolve the child-only service...
        assert!(w.get_optional::<PO>().is_some());
        // ...without it ever becoming visible through the root provider.
        assert!(r.get_optional::<PO>().is_none());
    }

    #[test]
    fn missing_service_is_none() {
        let r = Arc::new(ServiceCollection::new().build().unwrap());
        let c = ServiceCollection::new().build().unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r);
        assert!(w.get_optional::<PO>().is_none());
    }

    #[test]
    #[should_panic(expected = "service not registered")]
    fn missing_service_panics_on_get() {
        let r = Arc::new(ServiceCollection::new().build().unwrap());
        let c = ServiceCollection::new().build().unwrap();
        let w = ServiceProviderWrapper::new(Arc::new(c), r);
        let _ = w.get::<PO>();
    }
}