1use std::any::{Any, TypeId};
7use std::collections::HashMap;
8use std::sync::Arc;
9
/// Marker trait for types that can be stored in and resolved from a `Container`.
///
/// Implementors must be thread-safe (`Send + Sync`) and must not borrow
/// non-`'static` data. Together these are what allow a service to be
/// type-erased into a `dyn Any + Send + Sync` handle.
pub trait Injectable
where
    Self: Send + Sync + 'static,
{
}
12
/// Type-erased, reference-counted handle under which every service is stored.
///
/// `Send + Sync` keep the erased value shareable across threads, and `Any`
/// allows the concrete `Arc<T>` to be recovered with `Arc::downcast`.
type ServiceBox = std::sync::Arc<dyn std::any::Any + Send + Sync>;
15
16#[derive(Clone, Default)]
30pub struct Container {
31 services: HashMap<TypeId, ServiceBox>,
32}
33
34impl Container {
35 pub fn new() -> Self {
37 Self {
38 services: HashMap::new(),
39 }
40 }
41
42 pub fn register<T: Injectable>(&mut self, service: Arc<T>) {
53 let type_id = self.get_type_id::<T>();
54 self.insert_service(type_id, service);
55 }
56
57 fn get_type_id<T: Injectable>(&self) -> TypeId {
59 TypeId::of::<T>()
60 }
61
62 fn insert_service<T: Injectable>(&mut self, type_id: TypeId, service: Arc<T>) {
64 self.services.insert(type_id, service as ServiceBox);
65 }
66
67 pub fn register_factory<T: Injectable, F>(&mut self, factory: F)
77 where
78 F: FnOnce() -> T,
79 {
80 let service = self.create_service(factory);
81 self.register(service);
82 }
83
84 fn create_service<T: Injectable, F>(& self, factory: F) -> Arc<T>
86 where
87 F: FnOnce() -> T,
88 {
89 Arc::new(factory())
90 }
91
92 pub fn resolve<T: Injectable>(&self) -> Option<Arc<T>> {
102 let type_id = self.get_type_id::<T>();
103 self.lookup_service(type_id)
104 }
105
106 fn lookup_service<T: Injectable>(&self, type_id: TypeId) -> Option<Arc<T>> {
108 self.services
109 .get(&type_id)
110 .and_then(|boxed| self.downcast_service(boxed))
111 }
112
113 fn downcast_service<T: Injectable>(&self, boxed: &ServiceBox) -> Option<Arc<T>> {
115 boxed.clone().downcast::<T>().ok()
116 }
117
118 pub fn resolve_or_panic<T: Injectable>(&self) -> Arc<T> {
124 self.resolve()
125 .unwrap_or_else(|| panic!("Service {} not registered", std::any::type_name::<T>()))
126 }
127
128 pub fn contains<T: Injectable>(&self) -> bool {
130 let type_id = TypeId::of::<T>();
131 self.services.contains_key(&type_id)
132 }
133
134 pub fn len(&self) -> usize {
136 self.services.len()
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.services.is_empty()
142 }
143
144 pub fn clear(&mut self) {
146 self.services.clear();
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    struct MockDatabase {
        connection_string: String,
    }

    impl Injectable for MockDatabase {}

    impl MockDatabase {
        fn new(conn: &str) -> Self {
            Self {
                connection_string: conn.to_string(),
            }
        }
    }

    struct MockUserService {
        db: Arc<MockDatabase>,
    }

    impl Injectable for MockUserService {}

    impl MockUserService {
        fn new(db: Arc<MockDatabase>) -> Self {
            Self { db }
        }
    }

    #[test]
    fn test_register_and_resolve() {
        let mut container = Container::new();
        let db = Arc::new(MockDatabase::new("postgres://localhost"));

        container.register(Arc::clone(&db));

        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "postgres://localhost");
        // The container hands back the very instance that was registered, not a copy.
        assert!(Arc::ptr_eq(&resolved, &db));
    }

    #[test]
    fn test_register_factory() {
        let mut container = Container::new();

        container.register_factory(|| MockDatabase::new("sqlite::memory"));

        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "sqlite::memory");
    }

    #[test]
    fn test_resolve_missing_service() {
        let mut container = Container::new();
        let result: Option<Arc<MockDatabase>> = container.resolve();
        assert!(result.is_none());

        // Registering one type must not make an unrelated type resolvable.
        container.register_factory(|| MockDatabase::new("test"));
        let other: Option<Arc<MockUserService>> = container.resolve();
        assert!(other.is_none());
    }

    #[test]
    #[should_panic(expected = "not registered")]
    fn test_resolve_or_panic() {
        let container = Container::new();
        let _: Arc<MockDatabase> = container.resolve_or_panic();
    }

    #[test]
    fn test_dependency_chain() {
        let mut container = Container::new();

        let db = Arc::new(MockDatabase::new("postgres://localhost"));
        container.register(Arc::clone(&db));

        let user_service = Arc::new(MockUserService::new(db));
        container.register(user_service);

        let resolved_db: Arc<MockDatabase> = container.resolve().unwrap();
        let resolved_service: Arc<MockUserService> = container.resolve().unwrap();

        assert_eq!(resolved_db.connection_string, "postgres://localhost");
        assert_eq!(
            resolved_service.db.connection_string,
            "postgres://localhost"
        );
        // Both paths lead to the same shared database instance.
        assert!(Arc::ptr_eq(&resolved_db, &resolved_service.db));
    }

    #[test]
    fn test_contains() {
        let mut container = Container::new();
        assert!(!container.contains::<MockDatabase>());

        container.register_factory(|| MockDatabase::new("test"));
        assert!(container.contains::<MockDatabase>());
        assert!(!container.contains::<MockUserService>());
    }

    #[test]
    fn test_len_and_clear() {
        let mut container = Container::new();
        assert_eq!(container.len(), 0);
        assert!(container.is_empty());

        container.register_factory(|| MockDatabase::new("test"));
        assert_eq!(container.len(), 1);
        assert!(!container.is_empty());

        container.clear();
        assert_eq!(container.len(), 0);
        assert!(container.is_empty());
    }

    #[test]
    fn test_register_replaces_existing() {
        let mut container = Container::new();
        container.register_factory(|| MockDatabase::new("first"));
        container.register_factory(|| MockDatabase::new("second"));

        // Only one service per type is kept; the latest registration wins.
        assert_eq!(container.len(), 1);
        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "second");
    }

    #[test]
    fn test_clone_shares_instances_but_not_registrations() {
        let mut original = Container::new();
        original.register_factory(|| MockDatabase::new("shared"));

        let mut copy = original.clone();
        let from_original: Arc<MockDatabase> = original.resolve().unwrap();
        let from_copy: Arc<MockDatabase> = copy.resolve().unwrap();
        assert!(Arc::ptr_eq(&from_original, &from_copy));

        copy.clear();
        assert!(copy.is_empty());
        assert!(original.contains::<MockDatabase>());
    }
}