1use std::{any::TypeId, collections::BTreeMap, fmt::Debug};
2
3use inventory::{Collect, Registry};
4use thiserror::Error;
5
/// Abstract factory that produces boxed values of type `T`.
///
/// `T` may be unsized (for example a trait object), which is why every
/// product is handed back behind a `Box`.
pub trait Factory<T>
where
    T: ?Sized,
{
    /// Builds a fresh product and transfers ownership of it to the caller.
    fn create(&self) -> Box<T>;
}
13
14#[derive(Debug, Error)]
16pub enum FactoryError {
17 #[error("factory with ID '{0}' not found")]
19 FactoryNotFound(String),
20
21 #[error("empty ID provided without fallback")]
23 EmptyIdNoFallback,
24
25 #[error("no factories available")]
27 NoFactoriesAvailable,
28}
29
/// Decides what [`SimpleFactory::create`] does when it is given an empty ID.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FactoryFallback {
    /// Use the factory whose ID sorts first in the ordered factory map.
    First,

    /// Use the factory whose ID sorts last in the ordered factory map.
    Last,

    /// Do not substitute anything; an empty ID is reported as an error.
    NoFallback,
}
42
43pub struct SimpleFactory<T: ?Sized + 'static>(
126 BTreeMap<&'static str, &'static (dyn Factory<T> + Sync)>,
127);
128
129impl<T> SimpleFactory<T>
130where
131 T: ?Sized + 'static,
132{
133 pub fn create(
150 &self,
151 id: impl AsRef<str>,
152 strategy: FactoryFallback,
153 ) -> Result<Box<T>, FactoryError> {
154 let id = id.as_ref();
155 if !id.is_empty() {
156 return if let Some(factory) = self.0.get(id) {
157 Ok(factory.create())
158 } else {
159 Err(FactoryError::FactoryNotFound(id.to_string()))
160 };
161 }
162
163 match strategy {
164 FactoryFallback::First => {
165 if let Some((_, factory)) = self.0.first_key_value() {
166 return Ok(factory.create());
167 }
168 }
169 FactoryFallback::Last => {
170 if let Some((_, factory)) = self.0.last_key_value() {
171 return Ok(factory.create());
172 }
173 }
174 FactoryFallback::NoFallback => return Err(FactoryError::EmptyIdNoFallback),
175 }
176
177 Err(FactoryError::NoFactoriesAvailable)
178 }
179}
180
181pub struct FactoryRegistry<T>
186where
187 T: ?Sized + 'static,
188{
189 id: &'static str,
194
195 factory: &'static (dyn Factory<T> + Sync),
200
201 type_id: TypeId,
206}
207
208impl<T> Collect for FactoryRegistry<T>
209where
210 T: ?Sized + 'static,
211{
212 fn registry() -> &'static Registry {
213 static REGISTRY: Registry = Registry::new();
214
215 ®ISTRY
216 }
217}
218
219impl<T> FactoryRegistry<T>
220where
221 T: ?Sized + 'static,
222{
223 #[inline]
229 pub const fn new(id: &'static str, factory: &'static (dyn Factory<T> + Sync)) -> Self {
230 Self {
231 id,
232 factory,
233 type_id: TypeId::of::<T>(),
234 }
235 }
236
237 pub fn simple_factory() -> SimpleFactory<T> {
248 let type_id = TypeId::of::<T>();
249 let factories = inventory::iter::<Self>()
250 .filter_map(|reg| (type_id == reg.type_id).then_some((reg.id, reg.factory)))
251 .collect();
252
253 SimpleFactory(factories)
254 }
255}
256
/// Registers `$implement` as a factory for `$product` under the ID `$id`.
///
/// * `$product` — the product type handed out by the factory (may be a
///   `dyn Trait`).
/// * `$id` — a non-empty literal; emptiness is rejected at compile time.
/// * `$implement` — the concrete type to build; it must implement `Default`,
///   which is also checked at compile time.
///
/// The helper macros (`const_assert!`, `assert_impl_one!`, `submit!`) are
/// resolved through `$crate`, so they are expected to be re-exported from the
/// crate root — NOTE(review): confirm those re-exports exist.
#[macro_export]
macro_rules! register_factory {
    ($product:ty, $id:literal, $implement:ty) => {
        // Compile-time guards: the ID must not be empty and the
        // implementation must be default-constructible.
        $crate::const_assert!(!$id.is_empty());
        $crate::assert_impl_one!($implement: Default);

        // The anonymous const scopes the helper type, so repeated invocations
        // in one module do not clash on its name.
        const _: () = {
            struct DefaultBackedFactory;

            impl $crate::Factory<$product> for DefaultBackedFactory {
                fn create(&self) -> Box<$product> {
                    Box::new(<$implement as Default>::default())
                }
            }

            $crate::submit! {
                $crate::FactoryRegistry::new(
                    $id,
                    &DefaultBackedFactory as &'static (dyn $crate::Factory<$product> + Sync),
                )
            }
        };
    };
}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Product trait used to exercise registration and lookup.
    trait TestProduct {
        fn get_value(&self) -> &str;
    }

    /// First registered implementation; its default value is `"default_a"`.
    struct ProductA {
        value: String,
    }

    impl TestProduct for ProductA {
        fn get_value(&self) -> &str {
            &self.value
        }
    }

    impl Default for ProductA {
        fn default() -> Self {
            Self {
                value: "default_a".to_string(),
            }
        }
    }

    /// Second registered implementation; its default value is `"default_b"`.
    struct ProductB {
        value: String,
    }

    impl TestProduct for ProductB {
        fn get_value(&self) -> &str {
            &self.value
        }
    }

    impl Default for ProductB {
        fn default() -> Self {
            Self {
                value: "default_b".to_string(),
            }
        }
    }

    // "product_a" sorts before "product_b", which fixes what the `First` and
    // `Last` fallbacks must return in the tests below.
    register_factory!(dyn TestProduct, "product_a", ProductA);
    register_factory!(dyn TestProduct, "product_b", ProductB);

    /// Builds the lookup table for the `dyn TestProduct` registrations above.
    fn registered() -> SimpleFactory<dyn TestProduct> {
        FactoryRegistry::<dyn TestProduct>::simple_factory()
    }

    #[test]
    fn test_factory_registration() {
        let factory = registered();

        assert!(
            factory
                .create("product_a", FactoryFallback::NoFallback)
                .is_ok(),
            "product_a factory should exist"
        );
        assert!(
            factory
                .create("product_b", FactoryFallback::NoFallback)
                .is_ok(),
            "product_b factory should exist"
        );
    }

    #[test]
    fn test_factory_creation() {
        let factory = registered();

        let product = factory
            .create("product_a", FactoryFallback::NoFallback)
            .expect("product_a factory should exist");
        assert_eq!(product.get_value(), "default_a");

        let product = factory
            .create("product_b", FactoryFallback::NoFallback)
            .expect("product_b factory should exist");
        assert_eq!(product.get_value(), "default_b");
    }

    #[test]
    fn test_factory_error_cases() {
        let factory = registered();

        assert!(
            matches!(
                factory.create("non_existent", FactoryFallback::NoFallback),
                Err(FactoryError::FactoryNotFound(id)) if id == "non_existent"
            ),
            "Expected FactoryNotFound error"
        );

        assert!(
            matches!(
                factory.create("", FactoryFallback::NoFallback),
                Err(FactoryError::EmptyIdNoFallback)
            ),
            "Expected EmptyIdNoFallback error"
        );
    }

    #[test]
    fn test_factory_fallback_first() {
        let factory = registered();

        // An empty ID with `First` must yield the factory with the smallest ID.
        let product = factory
            .create("", FactoryFallback::First)
            .expect("fallback should find a registered factory");
        assert_eq!(product.get_value(), "default_a");

        // A non-empty, unknown ID never falls back.
        assert!(
            matches!(
                factory.create("invalid_id", FactoryFallback::First),
                Err(FactoryError::FactoryNotFound(id)) if id == "invalid_id"
            ),
            "Expected FactoryNotFound for invalid ID"
        );
    }

    #[test]
    fn test_factory_fallback_last() {
        let factory = registered();

        // An empty ID with `Last` must yield the factory with the greatest ID.
        let product = factory
            .create("", FactoryFallback::Last)
            .expect("fallback should find a registered factory");
        assert_eq!(product.get_value(), "default_b");

        // A non-empty, unknown ID never falls back.
        assert!(
            matches!(
                factory.create("invalid_id", FactoryFallback::Last),
                Err(FactoryError::FactoryNotFound(id)) if id == "invalid_id"
            ),
            "Expected FactoryNotFound for invalid ID"
        );
    }

    #[test]
    fn test_factory_no_factories_available() {
        // Nothing is ever registered for this product type.
        trait EmptyProduct {
            // Never called; the trait only needs to exist as a distinct type.
            #[allow(dead_code)]
            fn dummy(&self);
        }

        let factory = FactoryRegistry::<dyn EmptyProduct>::simple_factory();

        assert!(
            matches!(
                factory.create("", FactoryFallback::First),
                Err(FactoryError::NoFactoriesAvailable)
            ),
            "Expected NoFactoriesAvailable error"
        );

        assert!(
            matches!(
                factory.create("", FactoryFallback::Last),
                Err(FactoryError::NoFactoriesAvailable)
            ),
            "Expected NoFactoriesAvailable error"
        );
    }

    #[test]
    fn test_factory_registry_new() {
        struct TestFactory;

        impl Factory<String> for TestFactory {
            fn create(&self) -> Box<String> {
                Box::new("test".to_string())
            }
        }

        let factory = &TestFactory as &'static (dyn Factory<String> + Sync);
        let registry = FactoryRegistry::new("test_id", factory);

        assert_eq!(registry.id, "test_id");
        assert_eq!(registry.type_id, TypeId::of::<String>());
    }

    #[test]
    fn test_factory_error_display() {
        let error = FactoryError::FactoryNotFound("test_id".to_string());
        assert_eq!(error.to_string(), "factory with ID 'test_id' not found");

        let error = FactoryError::EmptyIdNoFallback;
        assert_eq!(error.to_string(), "empty ID provided without fallback");

        let error = FactoryError::NoFactoriesAvailable;
        assert_eq!(error.to_string(), "no factories available");
    }

    #[test]
    fn test_factory_fallback_debug() {
        assert_eq!(format!("{:?}", FactoryFallback::First), "First");
        assert_eq!(format!("{:?}", FactoryFallback::Last), "Last");
        assert_eq!(format!("{:?}", FactoryFallback::NoFallback), "NoFallback");
    }

    #[test]
    fn test_factory_fallback_eq() {
        assert_eq!(FactoryFallback::First, FactoryFallback::First);
        assert_eq!(FactoryFallback::Last, FactoryFallback::Last);
        assert_eq!(FactoryFallback::NoFallback, FactoryFallback::NoFallback);
        assert_ne!(FactoryFallback::First, FactoryFallback::Last);
        assert_ne!(FactoryFallback::First, FactoryFallback::NoFallback);
    }

    #[test]
    fn test_simple_factory_debug() {
        // Smoke test only: the `SimpleFactory` definition visible in this file
        // derives no `Debug`, so this just checks the table is usable.
        let factory = registered();
        assert!(factory
            .create("product_a", FactoryFallback::NoFallback)
            .is_ok());
    }
}