Skip to main content

studiole_di/
service_provider.rs

1//! Service resolution.
2
3use crate::prelude::*;
4
5/// Resolve registered services.
6#[derive(Clone)]
7pub struct ServiceProvider {
8    /// Shared reference to the service registry.
9    pub(crate) registry: Arc<ServiceRegistry>,
10}
11
12impl ServiceProvider {
13    /// Resolve a concrete type.
14    pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, Report<ResolveError>> {
15        let type_name = type_name::<T>();
16        trace!(type_name, "Resolving service");
17        let type_id = TypeId::of::<T>();
18
19        if let Some(dynamic) = self.get_cached(type_id) {
20            return Ok(dynamic.expect_downcast::<T>());
21        }
22
23        let registration = self.get_registration(type_id, type_name)?;
24        #[cfg(feature = "async")]
25        if registration.is_async {
26            return Err(Report::new(ResolveError::Async)).attach("type", type_name);
27        }
28        let dynamic = (registration.factory)(self)?;
29        self.cache_if_singleton(type_id, registration.scope, &dynamic);
30        Ok(dynamic.expect_downcast::<T>())
31    }
32
33    /// Look up a cached instance by type.
34    pub(crate) fn get_cached(&self, type_id: TypeId) -> Option<Arc<dyn Any + Send + Sync>> {
35        let instances = self
36            .registry
37            .instances
38            .lock()
39            .expect("should be able to lock instances");
40        instances.get(&type_id).map(Arc::clone)
41    }
42
43    /// Look up a registration by type.
44    pub(crate) fn get_registration(
45        &self,
46        type_id: TypeId,
47        type_name: &'static str,
48    ) -> Result<&Registration, Report<ResolveError>> {
49        self.registry
50            .factories
51            .get(&type_id)
52            .ok_or_else(|| Report::new(ResolveError::NotFound))
53            .attach("type", type_name)
54    }
55
56    /// Cache an instance if the registration is a singleton.
57    pub(crate) fn cache_if_singleton(
58        &self,
59        type_id: TypeId,
60        scope: Scope,
61        dynamic: &Arc<dyn Any + Send + Sync>,
62    ) {
63        if scope == Scope::Singleton {
64            self.registry
65                .instances
66                .lock()
67                .expect("should be able to lock instances")
68                .insert(type_id, Arc::clone(dynamic));
69        }
70    }
71}
72
73/// Errors returned when resolving a service.
74#[derive(Clone, Copy, Debug, Eq, Error, PartialEq)]
75pub enum ResolveError {
76    /// No service was registered for the requested type.
77    #[error("Service not registered")]
78    NotFound,
79    /// The factory function failed during service construction.
80    #[error("Factory failed to construct service")]
81    Factory,
82    /// The service requires async resolution but was called synchronously.
83    #[cfg(feature = "async")]
84    #[error("Service requires async resolution")]
85    Async,
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn singleton_shares_state() {
        // Arrange
        let services = ServiceBuilder::new().with_type::<MemoryCache>().build();

        // Act
        let first = services.get::<MemoryCache>().expect("should resolve");
        first.set("key", "hello");
        let second = services.get::<MemoryCache>().expect("should resolve");

        // Assert
        // A singleton must be the very same allocation, not merely equal state.
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(second.get("key"), Some(String::from("hello")));
    }

    #[test]
    fn transient_does_not_share_state() {
        // Arrange
        let services = ServiceBuilder::new()
            .with_type_transient::<MemoryCache>()
            .build();

        // Act
        let first = services.get::<MemoryCache>().expect("should resolve");
        first.set("key", "hello");
        let second = services.get::<MemoryCache>().expect("should resolve");

        // Assert
        assert!(!Arc::ptr_eq(&first, &second));
        assert_eq!(second.get("key"), None);
    }

    #[test]
    fn unregistered_type_returns_not_found() {
        // Arrange
        let services = ServiceBuilder::new().build();

        // Act
        let result = services.get::<Config>();

        // Assert
        assert!(result.is_err());
    }

    #[test]
    fn resolve_instance() {
        // Arrange
        let services = ServiceBuilder::new()
            .with_instance(Config { port: 3000 })
            .build();

        // Act
        let config = services.get::<Config>().expect("should resolve");

        // Assert
        assert_eq!(config.port, 3000);
    }

    #[test]
    fn cloned_provider_shares_singleton() {
        // Arrange
        let services = ServiceBuilder::new().with_type::<MemoryCache>().build();

        // Act
        let first = services.get::<MemoryCache>().expect("should resolve");
        first.set("key", "hello");
        let cloned = services.clone();
        let second = cloned.get::<MemoryCache>().expect("should resolve");

        // Assert
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(second.get("key"), Some(String::from("hello")));
    }

    #[test]
    fn derived_struct_resolves() {
        // Arrange
        let services = ServiceBuilder::new()
            .with_instance(Config { port: 8080 })
            .with_type::<DerivedDatabase>()
            .build();

        // Act
        let db = services
            .get::<DerivedDatabase>()
            .expect("DerivedDatabase should resolve");

        // Assert
        assert_eq!(db.config.port, 8080);
    }

    #[test]
    fn unit_struct_resolves() {
        // Arrange
        let services = ServiceBuilder::new().with_type::<UnitService>().build();

        // Act
        let result = services.get::<UnitService>();

        // Assert
        assert!(result.is_ok());
    }

    #[test]
    fn mixed_default_fields_resolve() {
        // Arrange
        let services = ServiceBuilder::new()
            .with_instance(Config { port: 8080 })
            .with_type::<MixedService>()
            .build();

        // Act
        let svc = services
            .get::<MixedService>()
            .expect("MixedService should resolve");

        // Assert
        assert_eq!(svc.config.port, 8080);
        assert_eq!(svc.port, 0);
    }
}