Skip to main content

rust_pattern_components/
factory.rs

1use std::{any::TypeId, collections::BTreeMap, fmt::Debug};
2
3use inventory::{Collect, Registry};
4use thiserror::Error;
5
6/// 用于创建类型 `T` 实例的工厂 trait。
7///
8/// 实现此 trait 的类型可以创建目标类型的装箱实例。
9/// 类型 `T` 必须是 `Send + Sync` 并且可以是非固定大小类型。
10pub trait Factory<T: ?Sized> {
11    fn create(&self) -> Box<T>;
12}
13
14/// 通过工厂创建对象时可能发生的错误。
15#[derive(Debug, Error)]
16pub enum FactoryError {
17    /// 未找到指定的工厂。
18    #[error("未找到 ID 为 '{0}' 的工厂")]
19    FactoryNotFound(String),
20
21    /// 不允许回退时提供了空 ID。
22    #[error("不允许回退时提供了空 ID")]
23    EmptyIdNoFallback,
24
25    /// 请求的产品类型没有可用的工厂。
26    #[error("没有可用的工厂")]
27    NoFactoriesAvailable,
28}
29
30/// 当找不到指定 ID 的工厂时的处理策略。
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub enum FactoryFallback {
33    /// 如果找不到指定的工厂,则使用集合中的第一个工厂。
34    First,
35
36    /// 如果找不到指定的工厂,则使用集合中的最后一个工厂。
37    Last,
38
39    /// ID 为空时不回退,直接返回错误。
40    NoFallback,
41}
42
43/// 用于创建类型 `T` 实例的工厂集合。
44///
45/// 此结构体包含一个从工厂 ID 到工厂实例的映射,这些工厂实例可以创建
46/// 目标类型 `T` 的装箱实例。工厂在编译时使用 `inventory` crate 注册,
47/// 可以通过 `FactoryRegistry::factories()` 检索。
48///
49/// 类型 `T` 可以是非固定大小类型(trait 对象),并且必须具有 `'static` 生命周期。
50/// 工厂存储为静态引用,允许它们在线程间共享。
51///
52/// # 示例
53///
54/// 基本用法:
55///
56/// ```rust
57/// use rust_patterns_components::{FactoryFallback, FactoryRegistry};
58///
59/// // 定义产品 trait
60/// trait Product {
61///     fn name(&self) -> &str;
62/// }
63///
64/// // 假设已经注册了工厂(通过 inventory 机制)
65/// // register_factory!(dyn Product, "product_a", ProductA);
66/// // register_factory!(dyn Product, "product_b", ProductB);
67///
68/// // 获取工厂实例
69/// let factory = FactoryRegistry::<dyn Product>::simple_factory();
70///
71/// // 通过 ID 创建特定产品
72/// match factory.create("product_a", FactoryFallback::NoFallback) {
73///     Ok((id, product)) => {
74///         println!("创建了产品: {}, ID: {}", product.name(), id);
75///     }
76///     Err(e) => {
77///         println!("创建失败: {}", e);
78///     }
79/// }
80///
81/// // 使用回退策略
82/// let result = factory.create("", FactoryFallback::First);
83/// // 当 ID 为空时,使用第一个可用的工厂
84///
85/// let result = factory.create("", FactoryFallback::Last);
86/// // 当 ID 为空时,使用最后一个可用的工厂
87/// ```
88///
89/// 错误处理:
90///
91/// ```rust
92/// use rust_patterns_components::{FactoryFallback, FactoryRegistry, FactoryError};
93///
94/// // 定义产品 trait
95/// trait Product {
96///     fn name(&self) -> &str;
97/// }
98///
99/// let factory = FactoryRegistry::<dyn Product>::simple_factory();
100///
101/// // 不存在的工厂 ID
102/// match factory.create("nonexistent", FactoryFallback::NoFallback) {
103///     Err(FactoryError::FactoryNotFound(id)) => {
104///         println!("未找到工厂: {}", id);
105///     }
106///     _ => {}
107/// }
108///
109/// // 空 ID 且无回退策略
110/// match factory.create("", FactoryFallback::NoFallback) {
111///     Err(FactoryError::EmptyIdNoFallback) => {
112///         println!("空 ID 且未指定回退策略");
113///     }
114///     _ => {}
115/// }
116///
117/// // 没有可用的工厂
118/// match factory.create("any", FactoryFallback::NoFallback) {
119///     Err(FactoryError::NoFactoriesAvailable) => {
120///         println!("没有可用的工厂");
121///     }
122///     _ => {}
123/// }
124/// ```
125pub struct SimpleFactory<T: ?Sized + 'static>(
126    BTreeMap<&'static str, &'static (dyn Factory<T> + Sync)>,
127);
128
129impl<T> SimpleFactory<T>
130where
131    T: ?Sized + 'static,
132{
133    /// 使用指定的回退策略通过工厂模式创建实例。
134    ///
135    /// 此函数通过 ID 查找工厂并使用它创建实例。
136    /// 如果 `id` 为空,行为取决于 `strategy`:
137    /// - `NoFallback`:返回错误
138    /// - `First`:使用集合中的第一个工厂
139    /// - `Last`:使用集合中的最后一个工厂
140    ///   如果 `id` 不为空但找不到工厂,行为由 `strategy` 决定。
141    ///
142    /// # 参数
143    /// * `id` - 要使用的工厂标识符,或空字符串表示默认
144    /// * `strategy` - 找不到指定 ID 的工厂时使用的策略
145    ///
146    /// # 返回值
147    /// * `Ok((&str, Box<T>))` - 成功时返回包含使用的工厂 ID 和创建的实例的元组
148    /// * `Err(FactoryError)` - 如果找不到工厂或没有可用的工厂则返回错误
149    pub fn create<'a>(
150        &self,
151        id: &'a str,
152        strategy: FactoryFallback,
153    ) -> Result<(&'a str, Box<T>), FactoryError> {
154        if !id.is_empty() {
155            return if let Some(factory) = self.0.get(id) {
156                Ok((id, factory.create()))
157            } else {
158                Err(FactoryError::FactoryNotFound(id.to_string()))
159            };
160        }
161
162        match strategy {
163            FactoryFallback::First => {
164                if let Some((id, factory)) = self.0.first_key_value() {
165                    return Ok((id, factory.create()));
166                }
167            }
168            FactoryFallback::Last => {
169                if let Some((id, factory)) = self.0.last_key_value() {
170                    return Ok((id, factory.create()));
171                }
172            }
173            FactoryFallback::NoFallback => return Err(FactoryError::EmptyIdNoFallback),
174        }
175
176        Err(FactoryError::NoFactoriesAvailable)
177    }
178}
179
180/// 工厂实现的注册表条目。
181///
182/// 存储工厂的元数据,包括其 ID、产品类型 ID 和工厂实例。
183/// 此类型与 `inventory` crate 一起用于编译时注册。
184pub struct FactoryRegistry<T>
185where
186    T: ?Sized + 'static,
187{
188    /// 此工厂的唯一标识符。
189    ///
190    /// 此 ID 用于在创建实例时查找工厂。
191    /// 它必须是静态字符串字面量,并且对于给定产品类型 `T` 在注册表中应该是唯一的。
192    id: &'static str,
193
194    /// 创建类型 `T` 实例的工厂实例。
195    ///
196    /// 这是一个静态引用,指向可以创建产品类型 `T` 的装箱实例的工厂实现。
197    /// 工厂必须是线程安全的(`Sync`)以允许在线程间共享。
198    factory: &'static (dyn Factory<T> + Sync),
199
200    /// 产品类型 `T` 的类型标识符。
201    ///
202    /// 此字段存储产品类型 `T` 的 `TypeId`,用于在从注册表检索时按产品类型过滤工厂。
203    /// 它确保只有正确产品类型的工厂包含在工厂集合中。
204    type_id: TypeId,
205}
206
207impl<T> Collect for FactoryRegistry<T>
208where
209    T: ?Sized + 'static,
210{
211    fn registry() -> &'static Registry {
212        static REGISTRY: Registry = Registry::new();
213
214        &REGISTRY
215    }
216}
217
218impl<T> FactoryRegistry<T>
219where
220    T: ?Sized + 'static,
221{
222    /// 创建一个新的工厂注册表条目。
223    ///
224    /// # 参数
225    /// * `id` - 此工厂的唯一标识符
226    /// * `factory` - 创建产品的工厂实例
227    #[inline]
228    pub const fn new(id: &'static str, factory: &'static (dyn Factory<T> + Sync)) -> Self {
229        Self {
230            id,
231            factory,
232            type_id: TypeId::of::<T>(),
233        }
234    }
235
236    /// 查找产品类型 `T` 的所有已注册工厂。
237    ///
238    /// 此函数扫描编译时工厂注册表,并返回一个 `SimpleFactory` 实例,
239    /// 该实例包含一个从工厂 ID 到工厂实例的映射,这些工厂实例生产类型 `T` 的实例。
240    /// 只包含为确切产品类型 `T` 注册的工厂。
241    ///
242    /// # 返回值
243    /// 一个 `SimpleFactory<T>` 实例,包装一个 `BTreeMap`,其中:
244    /// - 键是工厂的静态字符串标识符
245    /// - 值是实现 `Factory<T>` 的工厂实例的引用
246    pub fn simple_factory() -> SimpleFactory<T> {
247        let type_id = TypeId::of::<T>();
248        let factories = inventory::iter::<Self>()
249            .filter_map(|reg| (type_id == reg.type_id).then_some((reg.id, reg.factory)))
250            .collect();
251
252        SimpleFactory(factories)
253    }
254}
255
256/// 为产品类型注册工厂实现的宏。
257///
258/// 注册一个工厂,该工厂创建 `$implement` 的实例作为 `$product` trait 的实现。
259#[macro_export]
260macro_rules! register_factory {
261    ($product:ty, $id:literal, $implement:ty) => {
262        $crate::const_assert!(!$id.is_empty());
263        $crate::assert_impl_one!($implement: Default);
264
265        const _: () = {
266            struct ConcreteFactory;
267
268            impl $crate::Factory<$product> for ConcreteFactory {
269                fn create(&self) -> Box<$product> {
270                    Box::<$implement>::default()
271                }
272            }
273
274            $crate::submit! {
275                $crate::FactoryRegistry::new(
276                    $id,
277                    &ConcreteFactory as &'static (dyn $crate::Factory<$product> + Sync),
278                )
279            }
280        };
281    };
282}
283
284#[cfg(test)]
285mod tests {
286    use super::*;
287
288    // 测试用的 trait 和实现
289    trait TestProduct {
290        fn get_value(&self) -> &str;
291    }
292
293    struct ProductA {
294        value: String,
295    }
296
297    impl ProductA {
298        #[allow(dead_code)]
299        fn new(value: &str) -> Self {
300            Self {
301                value: value.to_string(),
302            }
303        }
304    }
305
306    impl TestProduct for ProductA {
307        fn get_value(&self) -> &str {
308            &self.value
309        }
310    }
311
312    impl Default for ProductA {
313        fn default() -> Self {
314            Self {
315                value: "default_a".to_string(),
316            }
317        }
318    }
319
320    struct ProductB {
321        value: String,
322    }
323
324    impl ProductB {
325        #[allow(dead_code)]
326        fn new(value: &str) -> Self {
327            Self {
328                value: value.to_string(),
329            }
330        }
331    }
332
333    impl TestProduct for ProductB {
334        fn get_value(&self) -> &str {
335            &self.value
336        }
337    }
338
339    impl Default for ProductB {
340        fn default() -> Self {
341            Self {
342                value: "default_b".to_string(),
343            }
344        }
345    }
346
347    // 注册测试工厂
348    register_factory!(dyn TestProduct, "product_a", ProductA);
349    register_factory!(dyn TestProduct, "product_b", ProductB);
350
351    #[test]
352    fn test_factory_registration() {
353        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
354
355        // 测试我们注册的工厂应该存在
356        let result_a = factory.create("product_a", FactoryFallback::NoFallback);
357        assert!(result_a.is_ok(), "product_a factory should exist");
358
359        let result_b = factory.create("product_b", FactoryFallback::NoFallback);
360        assert!(result_b.is_ok(), "product_b factory should exist");
361    }
362
363    #[test]
364    fn test_factory_creation() {
365        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
366
367        // 测试创建 ProductA
368        let result = factory.create("product_a", FactoryFallback::NoFallback);
369        assert!(result.is_ok());
370
371        let (id, product) = result.unwrap();
372        assert_eq!(id, "product_a");
373        assert_eq!(product.get_value(), "default_a");
374
375        // 测试创建 ProductB
376        let result = factory.create("product_b", FactoryFallback::NoFallback);
377        assert!(result.is_ok());
378
379        let (id, product) = result.unwrap();
380        assert_eq!(id, "product_b");
381        assert_eq!(product.get_value(), "default_b");
382    }
383
384    #[test]
385    fn test_factory_error_cases() {
386        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
387
388        // 测试不存在的工厂 ID
389        let result = factory.create("non_existent", FactoryFallback::NoFallback);
390        assert!(result.is_err());
391
392        if let Err(FactoryError::FactoryNotFound(id)) = result {
393            assert_eq!(id, "non_existent");
394        } else {
395            panic!("Expected FactoryNotFound error");
396        }
397
398        // 测试空 ID 无回退
399        let result = factory.create("", FactoryFallback::NoFallback);
400        assert!(result.is_err());
401
402        if let Err(FactoryError::EmptyIdNoFallback) = result {
403            // 正确
404        } else {
405            panic!("Expected EmptyIdNoFallback error");
406        }
407    }
408
409    #[test]
410    fn test_factory_fallback_first() {
411        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
412
413        // 测试 First 回退策略(空 ID)
414        let result = factory.create("", FactoryFallback::First);
415        // 由于 inventory 是全局的,可能在其他测试中注册了工厂
416        // 所以这里可能成功也可能失败,我们只检查行为是否正确
417        match result {
418            Ok((id, _)) => {
419                // 如果成功,id 不应该为空
420                assert!(!id.is_empty());
421            }
422            Err(FactoryError::NoFactoriesAvailable) => {
423                // 如果没有工厂可用,这也是有效的
424            }
425            Err(e) => {
426                // 其他错误不应该发生
427                panic!("Unexpected error: {:?}", e);
428            }
429        }
430
431        // 测试 First 回退策略(无效 ID)
432        let result = factory.create("invalid_id", FactoryFallback::First);
433        match result {
434            Ok((id, _)) => {
435                // 如果成功,id 不应该为空
436                assert!(!id.is_empty());
437            }
438            Err(FactoryError::FactoryNotFound(id)) => {
439                // 如果找不到工厂,id 应该是 "invalid_id"
440                assert_eq!(id, "invalid_id");
441            }
442            Err(FactoryError::NoFactoriesAvailable) => {
443                // 如果没有工厂可用,这也是有效的
444            }
445            Err(e) => {
446                // 其他错误不应该发生
447                panic!("Unexpected error: {:?}", e);
448            }
449        }
450    }
451
452    #[test]
453    fn test_factory_fallback_last() {
454        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
455
456        // 测试 Last 回退策略(空 ID)
457        let result = factory.create("", FactoryFallback::Last);
458        // 由于 inventory 是全局的,可能在其他测试中注册了工厂
459        // 所以这里可能成功也可能失败,我们只检查行为是否正确
460        match result {
461            Ok((id, _)) => {
462                // 如果成功,id 不应该为空
463                assert!(!id.is_empty());
464            }
465            Err(FactoryError::NoFactoriesAvailable) => {
466                // 如果没有工厂可用,这也是有效的
467            }
468            Err(e) => {
469                // 其他错误不应该发生
470                panic!("Unexpected error: {:?}", e);
471            }
472        }
473
474        // 测试 Last 回退策略(无效 ID)
475        let result = factory.create("invalid_id", FactoryFallback::Last);
476        match result {
477            Ok((id, _)) => {
478                // 如果成功,id 不应该为空
479                assert!(!id.is_empty());
480            }
481            Err(FactoryError::FactoryNotFound(id)) => {
482                // 如果找不到工厂,id 应该是 "invalid_id"
483                assert_eq!(id, "invalid_id");
484            }
485            Err(FactoryError::NoFactoriesAvailable) => {
486                // 如果没有工厂可用,这也是有效的
487            }
488            Err(e) => {
489                // 其他错误不应该发生
490                panic!("Unexpected error: {:?}", e);
491            }
492        }
493    }
494
495    #[test]
496    fn test_factory_no_factories_available() {
497        // 测试没有工厂的情况
498        // 创建一个新的 trait 和工厂注册表,但不注册任何工厂
499        trait EmptyProduct {
500            #[allow(dead_code)]
501            fn dummy(&self);
502        }
503
504        let factory = FactoryRegistry::<dyn EmptyProduct>::simple_factory();
505
506        // 测试空工厂集合
507        let result = factory.create("", FactoryFallback::First);
508        assert!(result.is_err());
509
510        if let Err(FactoryError::NoFactoriesAvailable) = result {
511            // 正确
512        } else {
513            panic!("Expected NoFactoriesAvailable error");
514        }
515
516        let result = factory.create("", FactoryFallback::Last);
517        assert!(result.is_err());
518
519        if let Err(FactoryError::NoFactoriesAvailable) = result {
520            // 正确
521        } else {
522            panic!("Expected NoFactoriesAvailable error");
523        }
524    }
525
526    #[test]
527    fn test_factory_registry_new() {
528        // 测试 FactoryRegistry::new 函数
529        struct TestFactory;
530
531        impl Factory<String> for TestFactory {
532            fn create(&self) -> Box<String> {
533                Box::new("test".to_string())
534            }
535        }
536
537        let factory = &TestFactory as &'static (dyn Factory<String> + Sync);
538        let registry = FactoryRegistry::new("test_id", factory);
539
540        assert_eq!(registry.id, "test_id");
541        assert_eq!(registry.type_id, TypeId::of::<String>());
542    }
543
544    #[test]
545    fn test_factory_error_display() {
546        // 测试错误信息的显示
547        let error = FactoryError::FactoryNotFound("test_id".to_string());
548        assert_eq!(format!("{}", error), "未找到 ID 为 'test_id' 的工厂");
549
550        let error = FactoryError::EmptyIdNoFallback;
551        assert_eq!(format!("{}", error), "不允许回退时提供了空 ID");
552
553        let error = FactoryError::NoFactoriesAvailable;
554        assert_eq!(format!("{}", error), "没有可用的工厂");
555    }
556
557    #[test]
558    fn test_factory_fallback_debug() {
559        // 测试 FactoryFallback 的 Debug 实现
560        assert_eq!(format!("{:?}", FactoryFallback::First), "First");
561        assert_eq!(format!("{:?}", FactoryFallback::Last), "Last");
562        assert_eq!(format!("{:?}", FactoryFallback::NoFallback), "NoFallback");
563    }
564
565    #[test]
566    fn test_factory_fallback_eq() {
567        // 测试 FactoryFallback 的相等性
568        assert_eq!(FactoryFallback::First, FactoryFallback::First);
569        assert_eq!(FactoryFallback::Last, FactoryFallback::Last);
570        assert_eq!(FactoryFallback::NoFallback, FactoryFallback::NoFallback);
571        assert_ne!(FactoryFallback::First, FactoryFallback::Last);
572        assert_ne!(FactoryFallback::First, FactoryFallback::NoFallback);
573    }
574
575    #[test]
576    fn test_simple_factory_debug() {
577        // 测试 SimpleFactory 的 Debug 实现
578        // SimpleFactory 没有实现 Debug,所以跳过这个测试
579        // 或者我们可以测试工厂是否正常工作
580        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
581        let result = factory.create("product_a", FactoryFallback::NoFallback);
582        assert!(result.is_ok());
583    }
584}