// rust_api/di.rs

//! Dependency Injection Container
//!
//! A simple, type-safe DI container that stores services as `Arc`-wrapped trait objects.
//! Services are registered and retrieved by their concrete type; `register_factory`
//! wraps a freshly constructed service in an `Arc` for you.

6use std::any::{Any, TypeId};
7use std::collections::HashMap;
8use std::sync::Arc;
9
/// Marker trait for every type that can be stored in, and resolved from, a [`Container`].
///
/// The bounds are what type-erased, shared storage needs: `'static` so the value
/// can be held as `dyn Any`, and `Send + Sync` so `Arc` handles to it may cross threads.
pub trait Injectable: 'static + Send + Sync {}

/// Type-erased storage slot for one registered service.
///
/// Services live behind an `Arc`, so resolving one only bumps a reference count.
/// The `Send + Sync` bounds let the whole container be sent and shared across threads.
type ServiceBox = Arc<dyn Any + Send + Sync>;

/// Dependency injection container.
///
/// Holds at most one instance per concrete service type, keyed by that type's
/// [`TypeId`], and hands out shared `Arc` handles to it. Registering another
/// instance of an already-registered type replaces the previous one.
///
/// Cloning a `Container` copies the map of `Arc` handles. The clone therefore
/// shares the very same service instances; it does not construct new ones.
///
/// # Example
///
/// ```ignore
/// let mut container = Container::new();
/// container.register(Arc::new(DatabaseService::new()));
///
/// let db: Arc<DatabaseService> = container.resolve().unwrap();
/// ```
#[derive(Clone, Default, Debug)]
pub struct Container {
    /// Registered services, keyed by the `TypeId` of their concrete type.
    services: HashMap<TypeId, ServiceBox>,
}

34impl Container {
35    /// Create a new empty container
36    pub fn new() -> Self {
37        Self {
38            services: HashMap::new(),
39        }
40    }
41
42    /// Register a service in the container
43    ///
44    /// The service must be wrapped in an Arc. If a service of this type
45    /// already exists, it will be replaced.
46    ///
47    /// # Example
48    ///
49    /// ```ignore
50    /// container.register(Arc::new(MyService::new()));
51    /// ```
52    pub fn register<T: Injectable>(&mut self, service: Arc<T>) {
53        let type_id = self.get_type_id::<T>();
54        self.insert_service(type_id, service);
55    }
56
57    //get the TypeId for a given type T
58    fn get_type_id<T: Injectable>(&self) -> TypeId {
59        TypeId::of::<T>()
60    }
61
62    //insert a service into the storage map
63    fn insert_service<T: Injectable>(&mut self, type_id: TypeId, service: Arc<T>) {
64        self.services.insert(type_id, service as ServiceBox);
65    }
66
67    /// Register a service from a constructor function
68    ///
69    /// This is a convenience method that creates the Arc for you.
70    ///
71    /// # Example
72    ///
73    /// ```ignore
74    /// container.register_factory(|| MyService::new());
75    /// ```
76    pub fn register_factory<T: Injectable, F>(&mut self, factory: F)
77    where
78        F: FnOnce() -> T,
79    {
80        let service = self.create_service(factory);
81        self.register(service);
82    }
83
84    //create a service instance from a factory function
85    fn create_service<T: Injectable, F>(& self, factory: F) -> Arc<T>
86    where
87        F: FnOnce() -> T,
88    {
89        Arc::new(factory())
90    }
91
92    /// Resolve a service from the container
93    ///
94    /// Returns None if the service hasn't been registered.
95    ///
96    /// # Example
97    ///
98    /// ```ignore
99    /// let service: Arc<MyService> = container.resolve().unwrap();
100    /// ```
101    pub fn resolve<T: Injectable>(&self) -> Option<Arc<T>> {
102        let type_id = self.get_type_id::<T>();
103        self.lookup_service(type_id)
104    }
105
106    //lookup a service by TypeId and downcast it
107    fn lookup_service<T: Injectable>(&self, type_id: TypeId) -> Option<Arc<T>> {
108        self.services
109            .get(&type_id)
110            .and_then(|boxed| self.downcast_service(boxed))
111    }
112
113    //downcast a type-erased service to the concrete type
114    fn downcast_service<T: Injectable>(&self, boxed: &ServiceBox) -> Option<Arc<T>> {
115        boxed.clone().downcast::<T>().ok()
116    }
117
118    /// Resolve a service or panic if not found
119    ///
120    /// # Panics
121    ///
122    /// Panics if the service hasn't been registered.
123    pub fn resolve_or_panic<T: Injectable>(&self) -> Arc<T> {
124        self.resolve()
125            .unwrap_or_else(|| panic!("Service {} not registered", std::any::type_name::<T>()))
126    }
127
128    /// Check if a service is registered
129    pub fn contains<T: Injectable>(&self) -> bool {
130        let type_id = TypeId::of::<T>();
131        self.services.contains_key(&type_id)
132    }
133
134    /// Get the number of registered services
135    pub fn len(&self) -> usize {
136        self.services.len()
137    }
138
139    /// Check if the container is empty
140    pub fn is_empty(&self) -> bool {
141        self.services.is_empty()
142    }
143
144    /// Clear all services from the container
145    pub fn clear(&mut self) {
146        self.services.clear();
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaf service with no dependencies.
    struct MockDatabase {
        connection_string: String,
    }

    impl Injectable for MockDatabase {}

    impl MockDatabase {
        fn new(conn: &str) -> Self {
            Self {
                connection_string: conn.to_string(),
            }
        }
    }

    /// Service that depends on `MockDatabase`, used to exercise manual wiring.
    struct MockUserService {
        db: Arc<MockDatabase>,
    }

    impl Injectable for MockUserService {}

    impl MockUserService {
        fn new(db: Arc<MockDatabase>) -> Self {
            Self { db }
        }
    }

    #[test]
    fn test_register_and_resolve() {
        let mut container = Container::new();
        let db = Arc::new(MockDatabase::new("postgres://localhost"));

        container.register(Arc::clone(&db));

        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "postgres://localhost");
        // The container hands back the registered instance itself, not a copy.
        assert!(Arc::ptr_eq(&resolved, &db));
    }

    #[test]
    fn test_register_factory() {
        let mut container = Container::new();

        container.register_factory(|| MockDatabase::new("sqlite::memory"));

        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "sqlite::memory");
    }

    #[test]
    fn test_resolve_missing_service() {
        let container = Container::new();
        let result: Option<Arc<MockDatabase>> = container.resolve();
        assert!(result.is_none());
    }

    #[test]
    #[should_panic(expected = "not registered")]
    fn test_resolve_or_panic() {
        let container = Container::new();
        let _: Arc<MockDatabase> = container.resolve_or_panic();
    }

    #[test]
    fn test_resolve_or_panic_returns_registered_service() {
        let mut container = Container::new();
        container.register_factory(|| MockDatabase::new("test"));

        let db: Arc<MockDatabase> = container.resolve_or_panic();
        assert_eq!(db.connection_string, "test");
    }

    #[test]
    fn test_repeated_resolve_returns_same_instance() {
        let mut container = Container::new();
        container.register_factory(|| MockDatabase::new("test"));

        let first: Arc<MockDatabase> = container.resolve().unwrap();
        let second: Arc<MockDatabase> = container.resolve().unwrap();
        // Singleton semantics: both handles point at one allocation.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_register_replaces_existing_service() {
        let mut container = Container::new();
        container.register_factory(|| MockDatabase::new("first"));
        let old: Arc<MockDatabase> = container.resolve().unwrap();

        container.register_factory(|| MockDatabase::new("second"));

        assert_eq!(container.len(), 1);
        let current: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(current.connection_string, "second");
        // Handles obtained before the replacement keep the old instance alive.
        assert_eq!(old.connection_string, "first");
    }

    #[test]
    fn test_clone_shares_instances() {
        let mut container = Container::new();
        container.register_factory(|| MockDatabase::new("shared"));

        let copy = container.clone();

        let original: Arc<MockDatabase> = container.resolve().unwrap();
        let cloned: Arc<MockDatabase> = copy.resolve().unwrap();
        assert!(Arc::ptr_eq(&original, &cloned));
    }

    #[test]
    fn test_dependency_chain() {
        let mut container = Container::new();

        // Register database first
        let db = Arc::new(MockDatabase::new("postgres://localhost"));
        container.register(Arc::clone(&db));

        // Then register service that depends on it
        let user_service = Arc::new(MockUserService::new(db));
        container.register(user_service);

        // Resolve both
        let resolved_db: Arc<MockDatabase> = container.resolve().unwrap();
        let resolved_service: Arc<MockUserService> = container.resolve().unwrap();

        assert_eq!(resolved_db.connection_string, "postgres://localhost");
        assert_eq!(
            resolved_service.db.connection_string,
            "postgres://localhost"
        );
        // The service and the container share the same database instance.
        assert!(Arc::ptr_eq(&resolved_service.db, &resolved_db));
    }

    #[test]
    fn test_contains() {
        let mut container = Container::new();
        assert!(!container.contains::<MockDatabase>());

        container.register_factory(|| MockDatabase::new("test"));
        assert!(container.contains::<MockDatabase>());
        assert!(!container.contains::<MockUserService>());
    }

    #[test]
    fn test_len_and_clear() {
        let mut container = Container::new();
        assert_eq!(container.len(), 0);
        assert!(container.is_empty());

        container.register_factory(|| MockDatabase::new("test"));
        assert_eq!(container.len(), 1);
        assert!(!container.is_empty());

        container.clear();
        assert_eq!(container.len(), 0);
        assert!(container.is_empty());
    }
}