studiole_di/
service_provider.rs1use crate::prelude::*;
4
5#[derive(Clone)]
7pub struct ServiceProvider {
8 pub(crate) registry: Arc<ServiceRegistry>,
10}
11
12impl ServiceProvider {
13 pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, Report<ResolveError>> {
15 let type_name = type_name::<T>();
16 trace!(type_name, "Resolving service");
17 let type_id = TypeId::of::<T>();
18
19 if let Some(dynamic) = self.get_cached(type_id) {
20 return Ok(dynamic.expect_downcast::<T>());
21 }
22
23 let registration = self.get_registration(type_id, type_name)?;
24 #[cfg(feature = "async")]
25 if registration.is_async {
26 return Err(Report::new(ResolveError::Async)).attach("type", type_name);
27 }
28 let dynamic = (registration.factory)(self)?;
29 self.cache_if_singleton(type_id, registration.scope, &dynamic);
30 Ok(dynamic.expect_downcast::<T>())
31 }
32
33 pub(crate) fn get_cached(&self, type_id: TypeId) -> Option<Arc<dyn Any + Send + Sync>> {
35 let instances = self
36 .registry
37 .instances
38 .lock()
39 .expect("should be able to lock instances");
40 instances.get(&type_id).map(Arc::clone)
41 }
42
43 pub(crate) fn get_registration(
45 &self,
46 type_id: TypeId,
47 type_name: &'static str,
48 ) -> Result<&Registration, Report<ResolveError>> {
49 self.registry
50 .factories
51 .get(&type_id)
52 .ok_or_else(|| Report::new(ResolveError::NotFound))
53 .attach("type", type_name)
54 }
55
56 pub(crate) fn cache_if_singleton(
58 &self,
59 type_id: TypeId,
60 scope: Scope,
61 dynamic: &Arc<dyn Any + Send + Sync>,
62 ) {
63 if scope == Scope::Singleton {
64 self.registry
65 .instances
66 .lock()
67 .expect("should be able to lock instances")
68 .insert(type_id, Arc::clone(dynamic));
69 }
70 }
71}
72
73#[derive(Clone, Copy, Debug, Eq, Error, PartialEq)]
75pub enum ResolveError {
76 #[error("Service not registered")]
78 NotFound,
79 #[error("Factory failed to construct service")]
81 Factory,
82 #[cfg(feature = "async")]
84 #[error("Service requires async resolution")]
85 Async,
86}
87
88#[cfg(test)]
89mod tests {
90 use super::*;
91
92 #[test]
93 fn singleton_shares_state() {
94 let services = ServiceBuilder::new().with_type::<MemoryCache>().build();
96
97 let first = services.get::<MemoryCache>().expect("should resolve");
99 first.set("key", "hello");
100 let second = services.get::<MemoryCache>().expect("should resolve");
101
102 assert_eq!(second.get("key"), Some(String::from("hello")));
104 }
105
106 #[test]
107 fn transient_does_not_share_state() {
108 let services = ServiceBuilder::new()
110 .with_type_transient::<MemoryCache>()
111 .build();
112
113 let first = services.get::<MemoryCache>().expect("should resolve");
115 first.set("key", "hello");
116 let second = services.get::<MemoryCache>().expect("should resolve");
117
118 assert_eq!(second.get("key"), None);
120 }
121
122 #[test]
123 fn unregistered_type_returns_not_found() {
124 let services = ServiceBuilder::new().build();
126
127 let result = services.get::<Config>();
129
130 assert!(result.is_err());
132 }
133
134 #[test]
135 fn resolve_instance() {
136 let services = ServiceBuilder::new()
138 .with_instance(Config { port: 3000 })
139 .build();
140
141 let config = services.get::<Config>().expect("should resolve");
143
144 assert_eq!(config.port, 3000);
146 }
147
148 #[test]
149 fn cloned_provider_shares_singleton() {
150 let services = ServiceBuilder::new().with_type::<MemoryCache>().build();
152
153 let first = services.get::<MemoryCache>().expect("should resolve");
155 first.set("key", "hello");
156 let cloned = services.clone();
157 let second = cloned.get::<MemoryCache>().expect("should resolve");
158
159 assert_eq!(second.get("key"), Some(String::from("hello")));
161 }
162
163 #[test]
164 fn derived_struct_resolves() {
165 let services = ServiceBuilder::new()
167 .with_instance(Config { port: 8080 })
168 .with_type::<DerivedDatabase>()
169 .build();
170 let db = services
172 .get::<DerivedDatabase>()
173 .expect("DerivedDatabase should resolve");
174 assert_eq!(db.config.port, 8080);
176 }
177
178 #[test]
179 fn unit_struct_resolves() {
180 let services = ServiceBuilder::new().with_type::<UnitService>().build();
182 let result = services.get::<UnitService>();
184 assert!(result.is_ok());
186 }
187
188 #[test]
189 fn mixed_default_fields_resolve() {
190 let services = ServiceBuilder::new()
192 .with_instance(Config { port: 8080 })
193 .with_type::<MixedService>()
194 .build();
195 let svc = services
197 .get::<MixedService>()
198 .expect("MixedService should resolve");
199 assert_eq!(svc.config.port, 8080);
201 assert_eq!(svc.port, 0);
202 }
203}