1use std::{any::TypeId, collections::BTreeMap, fmt::Debug};
2
3use inventory::{Collect, Registry};
4use thiserror::Error;
5
/// Produces fresh, boxed instances of a product type `T`.
///
/// `T` may be unsized (typically a trait object such as `dyn MyProduct`), which is why the
/// product is always handed out behind a `Box`.
pub trait Factory<T>
where
    T: ?Sized,
{
    /// Builds a new product instance.
    fn create(&self) -> Box<T>;
}
13
14#[derive(Debug, Error)]
16pub enum FactoryError {
17 #[error("未找到 ID 为 '{0}' 的工厂")]
19 FactoryNotFound(String),
20
21 #[error("不允许回退时提供了空 ID")]
23 EmptyIdNoFallback,
24
25 #[error("没有可用的工厂")]
27 NoFactoriesAvailable,
28}
29
/// What `SimpleFactory::create` should do when it is asked for an empty id.
///
/// Registered ids are ordered lexicographically (they are the keys of a `BTreeMap`), so
/// `First` and `Last` select the smallest and the largest id respectively. The strategy is
/// ignored whenever a non-empty id is supplied.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FactoryFallback {
    /// Use the factory registered under the lexicographically smallest id.
    First,

    /// Use the factory registered under the lexicographically largest id.
    Last,

    /// Refuse to guess: an empty id yields `FactoryError::EmptyIdNoFallback`.
    NoFallback,
}
42
43pub struct SimpleFactory<T: ?Sized + 'static>(
126 BTreeMap<&'static str, &'static (dyn Factory<T> + Sync)>,
127);
128
129impl<T> SimpleFactory<T>
130where
131 T: ?Sized + 'static,
132{
133 pub fn create<'a>(
150 &self,
151 id: &'a str,
152 strategy: FactoryFallback,
153 ) -> Result<(&'a str, Box<T>), FactoryError> {
154 if !id.is_empty() {
155 return if let Some(factory) = self.0.get(id) {
156 Ok((id, factory.create()))
157 } else {
158 Err(FactoryError::FactoryNotFound(id.to_string()))
159 };
160 }
161
162 match strategy {
163 FactoryFallback::First => {
164 if let Some((id, factory)) = self.0.first_key_value() {
165 return Ok((id, factory.create()));
166 }
167 }
168 FactoryFallback::Last => {
169 if let Some((id, factory)) = self.0.last_key_value() {
170 return Ok((id, factory.create()));
171 }
172 }
173 FactoryFallback::NoFallback => return Err(FactoryError::EmptyIdNoFallback),
174 }
175
176 Err(FactoryError::NoFactoriesAvailable)
177 }
178}
179
180pub struct FactoryRegistry<T>
185where
186 T: ?Sized + 'static,
187{
188 id: &'static str,
193
194 factory: &'static (dyn Factory<T> + Sync),
199
200 type_id: TypeId,
205}
206
207impl<T> Collect for FactoryRegistry<T>
208where
209 T: ?Sized + 'static,
210{
211 fn registry() -> &'static Registry {
212 static REGISTRY: Registry = Registry::new();
213
214 ®ISTRY
215 }
216}
217
218impl<T> FactoryRegistry<T>
219where
220 T: ?Sized + 'static,
221{
222 #[inline]
228 pub const fn new(id: &'static str, factory: &'static (dyn Factory<T> + Sync)) -> Self {
229 Self {
230 id,
231 factory,
232 type_id: TypeId::of::<T>(),
233 }
234 }
235
236 pub fn simple_factory() -> SimpleFactory<T> {
247 let type_id = TypeId::of::<T>();
248 let factories = inventory::iter::<Self>()
249 .filter_map(|reg| (type_id == reg.type_id).then_some((reg.id, reg.factory)))
250 .collect();
251
252 SimpleFactory(factories)
253 }
254}
255
/// Registers `$implement` as the factory for product type `$product` under the id `$id`.
///
/// * `$product`: the product type handed out by the factory, e.g. `dyn MyTrait`.
/// * `$id`: a literal id. It is asserted at compile time to be non-empty.
/// * `$implement`: the concrete type to construct. It is asserted at compile time to implement
///   `Default`, and `Box<$implement>` must coerce to `Box<$product>`.
///
/// Registered factories are later gathered with `FactoryRegistry::<$product>::simple_factory()`.
///
/// The expansion refers to `$crate::const_assert`, `$crate::assert_impl_one`, `$crate::submit`,
/// `$crate::Factory` and `$crate::FactoryRegistry`, so those must be reachable at the crate root.
/// NOTE(review): the first three are presumably re-exports of `static_assertions` / `inventory`.
/// Confirm they stay exported.
#[macro_export]
macro_rules! register_factory {
    ($product:ty, $id:literal, $implement:ty) => {
        // Compile-time checks: a non-empty id, and an implementation constructible via `Default`.
        $crate::const_assert!(!$id.is_empty());
        $crate::assert_impl_one!($implement: Default);

        // The anonymous `const _` scope keeps `ConcreteFactory` private to this expansion, so the
        // macro can be invoked several times in one module without name clashes.
        const _: () = {
            struct ConcreteFactory;

            impl $crate::Factory<$product> for ConcreteFactory {
                fn create(&self) -> Box<$product> {
                    // `Box<$implement>` is coerced to `Box<$product>` by the return type.
                    Box::<$implement>::default()
                }
            }

            // Publish a `'static` registration record for this product type.
            $crate::submit! {
                $crate::FactoryRegistry::new(
                    $id,
                    &ConcreteFactory as &'static (dyn $crate::Factory<$product> + Sync),
                )
            }
        };
    };
}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal product trait used to exercise the registry.
    trait TestProduct {
        fn get_value(&self) -> &str;
    }

    struct ProductA {
        value: String,
    }

    impl TestProduct for ProductA {
        fn get_value(&self) -> &str {
            &self.value
        }
    }

    impl Default for ProductA {
        fn default() -> Self {
            Self {
                value: "default_a".to_string(),
            }
        }
    }

    struct ProductB {
        value: String,
    }

    impl TestProduct for ProductB {
        fn get_value(&self) -> &str {
            &self.value
        }
    }

    impl Default for ProductB {
        fn default() -> Self {
            Self {
                value: "default_b".to_string(),
            }
        }
    }

    // Exactly these two factories exist for `dyn TestProduct`, so fallback results are
    // deterministic: ids are ordered in a BTreeMap ("product_a" < "product_b").
    register_factory!(dyn TestProduct, "product_a", ProductA);
    register_factory!(dyn TestProduct, "product_b", ProductB);

    #[test]
    fn test_factory_registration() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();

        assert!(
            factory.create("product_a", FactoryFallback::NoFallback).is_ok(),
            "product_a factory should exist"
        );
        assert!(
            factory.create("product_b", FactoryFallback::NoFallback).is_ok(),
            "product_b factory should exist"
        );
    }

    #[test]
    fn test_factory_creation() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();

        let (id, product) = factory
            .create("product_a", FactoryFallback::NoFallback)
            .expect("product_a should be registered");
        assert_eq!(id, "product_a");
        assert_eq!(product.get_value(), "default_a");

        let (id, product) = factory
            .create("product_b", FactoryFallback::NoFallback)
            .expect("product_b should be registered");
        assert_eq!(id, "product_b");
        assert_eq!(product.get_value(), "default_b");
    }

    #[test]
    fn test_factory_error_cases() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();

        let result = factory.create("non_existent", FactoryFallback::NoFallback);
        assert!(
            matches!(&result, Err(FactoryError::FactoryNotFound(id)) if id == "non_existent"),
            "expected FactoryNotFound(\"non_existent\")"
        );

        let result = factory.create("", FactoryFallback::NoFallback);
        assert!(
            matches!(result, Err(FactoryError::EmptyIdNoFallback)),
            "expected EmptyIdNoFallback"
        );
    }

    #[test]
    fn test_factory_fallback_first() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();

        // "First" is the lexicographically smallest registered id.
        let (id, product) = factory
            .create("", FactoryFallback::First)
            .expect("First fallback should pick a registered factory");
        assert_eq!(id, "product_a");
        assert_eq!(product.get_value(), "default_a");

        // The fallback only applies to an empty id; an unknown non-empty id is still an error.
        let result = factory.create("invalid_id", FactoryFallback::First);
        assert!(
            matches!(&result, Err(FactoryError::FactoryNotFound(id)) if id == "invalid_id"),
            "expected FactoryNotFound(\"invalid_id\")"
        );
    }

    #[test]
    fn test_factory_fallback_last() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();

        // "Last" is the lexicographically largest registered id.
        let (id, product) = factory
            .create("", FactoryFallback::Last)
            .expect("Last fallback should pick a registered factory");
        assert_eq!(id, "product_b");
        assert_eq!(product.get_value(), "default_b");

        // The fallback only applies to an empty id; an unknown non-empty id is still an error.
        let result = factory.create("invalid_id", FactoryFallback::Last);
        assert!(
            matches!(&result, Err(FactoryError::FactoryNotFound(id)) if id == "invalid_id"),
            "expected FactoryNotFound(\"invalid_id\")"
        );
    }

    #[test]
    fn test_factory_no_factories_available() {
        trait EmptyProduct {
            #[allow(dead_code)] // never called: the trait only serves as an unregistered product type
            fn dummy(&self);
        }

        let factory = FactoryRegistry::<dyn EmptyProduct>::simple_factory();

        for strategy in [FactoryFallback::First, FactoryFallback::Last] {
            let result = factory.create("", strategy);
            assert!(
                matches!(result, Err(FactoryError::NoFactoriesAvailable)),
                "expected NoFactoriesAvailable for {strategy:?}"
            );
        }

        // NoFallback rejects the empty id before the (empty) table is consulted.
        assert!(matches!(
            factory.create("", FactoryFallback::NoFallback),
            Err(FactoryError::EmptyIdNoFallback)
        ));

        // A non-empty id is looked up exactly, regardless of strategy.
        assert!(matches!(
            factory.create("anything", FactoryFallback::First),
            Err(FactoryError::FactoryNotFound(_))
        ));
    }

    #[test]
    fn test_factory_registry_new() {
        struct TestFactory;

        impl Factory<String> for TestFactory {
            fn create(&self) -> Box<String> {
                Box::new("test".to_string())
            }
        }

        let factory = &TestFactory as &'static (dyn Factory<String> + Sync);
        let registry = FactoryRegistry::new("test_id", factory);

        assert_eq!(registry.id, "test_id");
        assert_eq!(registry.type_id, TypeId::of::<String>());
        assert_eq!(*registry.factory.create(), "test");
    }

    #[test]
    fn test_factory_error_display() {
        let error = FactoryError::FactoryNotFound("test_id".to_string());
        assert_eq!(format!("{}", error), "未找到 ID 为 'test_id' 的工厂");

        let error = FactoryError::EmptyIdNoFallback;
        assert_eq!(format!("{}", error), "不允许回退时提供了空 ID");

        let error = FactoryError::NoFactoriesAvailable;
        assert_eq!(format!("{}", error), "没有可用的工厂");
    }

    #[test]
    fn test_factory_fallback_debug() {
        assert_eq!(format!("{:?}", FactoryFallback::First), "First");
        assert_eq!(format!("{:?}", FactoryFallback::Last), "Last");
        assert_eq!(format!("{:?}", FactoryFallback::NoFallback), "NoFallback");
    }

    #[test]
    fn test_factory_fallback_eq() {
        assert_eq!(FactoryFallback::First, FactoryFallback::First);
        assert_eq!(FactoryFallback::Last, FactoryFallback::Last);
        assert_eq!(FactoryFallback::NoFallback, FactoryFallback::NoFallback);
        assert_ne!(FactoryFallback::First, FactoryFallback::Last);
        assert_ne!(FactoryFallback::First, FactoryFallback::NoFallback);
        assert_ne!(FactoryFallback::Last, FactoryFallback::NoFallback);
    }

    #[test]
    fn test_simple_factory_debug() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
        let result = factory.create("product_a", FactoryFallback::NoFallback);
        assert!(result.is_ok());
    }
}