Skip to main content

rust_pattern_components/
factory.rs

1use std::{any::TypeId, collections::BTreeMap, fmt::Debug};
2
3use inventory::{Collect, Registry};
4use thiserror::Error;
5
6/// 用于创建类型 `T` 实例的工厂 trait。
7///
8/// 实现此 trait 的类型可以创建目标类型的装箱实例。
9/// 类型 `T` 必须是 `Send + Sync` 并且可以是非固定大小类型。
10pub trait Factory<T: ?Sized> {
11    fn create(&self) -> Box<T>;
12}
13
14/// Errors that can occur when creating objects through a factory.
15#[derive(Debug, Error)]
16pub enum FactoryError {
17    /// The specified factory was not found.
18    #[error("factory with ID '{0}' not found")]
19    FactoryNotFound(String),
20
21    /// An empty ID was provided when fallback is not allowed.
22    #[error("empty ID provided without fallback")]
23    EmptyIdNoFallback,
24
25    /// No factories available for the requested product type.
26    #[error("no factories available")]
27    NoFactoriesAvailable,
28}
29
30/// 当找不到指定 ID 的工厂时的处理策略。
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub enum FactoryFallback {
33    /// 如果找不到指定的工厂,则使用集合中的第一个工厂。
34    First,
35
36    /// 如果找不到指定的工厂,则使用集合中的最后一个工厂。
37    Last,
38
39    /// ID 为空时不回退,直接返回错误。
40    NoFallback,
41}
42
43/// 用于创建类型 `T` 实例的工厂集合。
44///
45/// 此结构体包含一个从工厂 ID 到工厂实例的映射,这些工厂实例可以创建
46/// 目标类型 `T` 的装箱实例。工厂在编译时使用 `inventory` crate 注册,
47/// 可以通过 `FactoryRegistry::factories()` 检索。
48///
49/// 类型 `T` 可以是非固定大小类型(trait 对象),并且必须具有 `'static` 生命周期。
50/// 工厂存储为静态引用,允许它们在线程间共享。
51///
52/// # 示例
53///
54/// 基本用法:
55///
56/// ```rust,no_run
57/// use rust_pattern_components::{FactoryFallback, FactoryRegistry};
58///
59/// // 定义产品 trait
60/// trait Product {
61///     fn name(&self) -> &str;
62/// }
63///
64/// // 假设已经注册了工厂(通过 inventory 机制)
65/// // register_factory!(dyn Product, "product_a", ProductA);
66/// // register_factory!(dyn Product, "product_b", ProductB);
67///
68/// // 获取工厂实例
69/// let factory = FactoryRegistry::<dyn Product>::simple_factory();
70///
71/// // 通过 ID 创建特定产品
72/// match factory.create("product_a", FactoryFallback::NoFallback) {
73///     Ok(product) => {
74///         println!("创建了产品: {}", product.name());
75///     }
76///     Err(e) => {
77///         println!("创建失败: {}", e);
78///     }
79/// }
80///
81/// // 使用回退策略
82/// let result = factory.create("", FactoryFallback::First);
83/// // 当 ID 为空时,使用第一个可用的工厂
84///
85/// let result = factory.create("", FactoryFallback::Last);
86/// // 当 ID 为空时,使用最后一个可用的工厂
87/// ```
88///
89/// 错误处理:
90///
91/// ```rust,no_run
92/// use rust_pattern_components::{FactoryFallback, FactoryRegistry, FactoryError};
93///
94/// // 定义产品 trait
95/// trait Product {
96///     fn name(&self) -> &str;
97/// }
98///
99/// let factory = FactoryRegistry::<dyn Product>::simple_factory();
100///
101/// // 不存在的工厂 ID
102/// match factory.create("nonexistent", FactoryFallback::NoFallback) {
103///     Err(FactoryError::FactoryNotFound(id)) => {
104///         println!("未找到工厂: {}", id);
105///     }
106///     _ => {}
107/// }
108///
109/// // 空 ID 且无回退策略
110/// match factory.create("", FactoryFallback::NoFallback) {
111///     Err(FactoryError::EmptyIdNoFallback) => {
112///         println!("空 ID 且未指定回退策略");
113///     }
114///     _ => {}
115/// }
116///
117/// // 没有可用的工厂
118/// match factory.create("any", FactoryFallback::NoFallback) {
119///     Err(FactoryError::NoFactoriesAvailable) => {
120///         println!("没有可用的工厂");
121///     }
122///     _ => {}
123/// }
124/// ```
125pub struct SimpleFactory<T: ?Sized + 'static>(
126    BTreeMap<&'static str, &'static (dyn Factory<T> + Sync)>,
127);
128
129impl<T> SimpleFactory<T>
130where
131    T: ?Sized + 'static,
132{
133    /// 使用指定的回退策略通过工厂模式创建实例。
134    ///
135    /// 此函数通过 ID 查找工厂并使用它创建实例。
136    /// 如果 `id` 为空,行为取决于 `strategy`:
137    /// - `NoFallback`:返回错误
138    /// - `First`:使用集合中的第一个工厂
139    /// - `Last`:使用集合中的最后一个工厂
140    ///   如果 `id` 不为空但找不到工厂,行为由 `strategy` 决定。
141    ///
142    /// # 参数
143    /// * `id` - 要使用的工厂标识符,或空字符串表示默认
144    /// * `strategy` - 找不到指定 ID 的工厂时使用的策略
145    ///
146    /// # 返回值
147    /// * `Ok((&str, Box<T>))` - 成功时返回包含使用的工厂 ID 和创建的实例的元组
148    /// * `Err(FactoryError)` - 如果找不到工厂或没有可用的工厂则返回错误
149    pub fn create(
150        &self,
151        id: impl AsRef<str>,
152        strategy: FactoryFallback,
153    ) -> Result<Box<T>, FactoryError> {
154        let id = id.as_ref();
155        if !id.is_empty() {
156            return if let Some(factory) = self.0.get(id) {
157                Ok(factory.create())
158            } else {
159                Err(FactoryError::FactoryNotFound(id.to_string()))
160            };
161        }
162
163        match strategy {
164            FactoryFallback::First => {
165                if let Some((_, factory)) = self.0.first_key_value() {
166                    return Ok(factory.create());
167                }
168            }
169            FactoryFallback::Last => {
170                if let Some((_, factory)) = self.0.last_key_value() {
171                    return Ok(factory.create());
172                }
173            }
174            FactoryFallback::NoFallback => return Err(FactoryError::EmptyIdNoFallback),
175        }
176
177        Err(FactoryError::NoFactoriesAvailable)
178    }
179}
180
181/// 工厂实现的注册表条目。
182///
183/// 存储工厂的元数据,包括其 ID、产品类型 ID 和工厂实例。
184/// 此类型与 `inventory` crate 一起用于编译时注册。
185pub struct FactoryRegistry<T>
186where
187    T: ?Sized + 'static,
188{
189    /// 此工厂的唯一标识符。
190    ///
191    /// 此 ID 用于在创建实例时查找工厂。
192    /// 它必须是静态字符串字面量,并且对于给定产品类型 `T` 在注册表中应该是唯一的。
193    id: &'static str,
194
195    /// 创建类型 `T` 实例的工厂实例。
196    ///
197    /// 这是一个静态引用,指向可以创建产品类型 `T` 的装箱实例的工厂实现。
198    /// 工厂必须是线程安全的(`Sync`)以允许在线程间共享。
199    factory: &'static (dyn Factory<T> + Sync),
200
201    /// 产品类型 `T` 的类型标识符。
202    ///
203    /// 此字段存储产品类型 `T` 的 `TypeId`,用于在从注册表检索时按产品类型过滤工厂。
204    /// 它确保只有正确产品类型的工厂包含在工厂集合中。
205    type_id: TypeId,
206}
207
208impl<T> Collect for FactoryRegistry<T>
209where
210    T: ?Sized + 'static,
211{
212    fn registry() -> &'static Registry {
213        static REGISTRY: Registry = Registry::new();
214
215        &REGISTRY
216    }
217}
218
219impl<T> FactoryRegistry<T>
220where
221    T: ?Sized + 'static,
222{
223    /// 创建一个新的工厂注册表条目。
224    ///
225    /// # 参数
226    /// * `id` - 此工厂的唯一标识符
227    /// * `factory` - 创建产品的工厂实例
228    #[inline]
229    pub const fn new(id: &'static str, factory: &'static (dyn Factory<T> + Sync)) -> Self {
230        Self {
231            id,
232            factory,
233            type_id: TypeId::of::<T>(),
234        }
235    }
236
237    /// 查找产品类型 `T` 的所有已注册工厂。
238    ///
239    /// 此函数扫描编译时工厂注册表,并返回一个 `SimpleFactory` 实例,
240    /// 该实例包含一个从工厂 ID 到工厂实例的映射,这些工厂实例生产类型 `T` 的实例。
241    /// 只包含为确切产品类型 `T` 注册的工厂。
242    ///
243    /// # 返回值
244    /// 一个 `SimpleFactory<T>` 实例,包装一个 `BTreeMap`,其中:
245    /// - 键是工厂的静态字符串标识符
246    /// - 值是实现 `Factory<T>` 的工厂实例的引用
247    pub fn simple_factory() -> SimpleFactory<T> {
248        let type_id = TypeId::of::<T>();
249        let factories = inventory::iter::<Self>()
250            .filter_map(|reg| (type_id == reg.type_id).then_some((reg.id, reg.factory)))
251            .collect();
252
253        SimpleFactory(factories)
254    }
255}
256
257/// 为产品类型注册工厂实现的宏。
258///
259/// 注册一个工厂,该工厂创建 `$implement` 的实例作为 `$product` trait 的实现。
260#[macro_export]
261macro_rules! register_factory {
262    ($product:ty, $id:literal, $implement:ty) => {
263        $crate::const_assert!(!$id.is_empty());
264        $crate::assert_impl_one!($implement: Default);
265
266        const _: () = {
267            struct ConcreteFactory;
268
269            impl $crate::Factory<$product> for ConcreteFactory {
270                fn create(&self) -> Box<$product> {
271                    Box::<$implement>::default()
272                }
273            }
274
275            $crate::submit! {
276                $crate::FactoryRegistry::new(
277                    $id,
278                    &ConcreteFactory as &'static (dyn $crate::Factory<$product> + Sync),
279                )
280            }
281        };
282    };
283}
284
285#[cfg(test)]
286mod tests {
287    use super::*;
288
289    // 测试用的 trait 和实现
290    trait TestProduct {
291        fn get_value(&self) -> &str;
292    }
293
294    struct ProductA {
295        value: String,
296    }
297
298    impl ProductA {
299        #[allow(dead_code)]
300        fn new(value: &str) -> Self {
301            Self {
302                value: value.to_string(),
303            }
304        }
305    }
306
307    impl TestProduct for ProductA {
308        fn get_value(&self) -> &str {
309            &self.value
310        }
311    }
312
313    impl Default for ProductA {
314        fn default() -> Self {
315            Self {
316                value: "default_a".to_string(),
317            }
318        }
319    }
320
321    struct ProductB {
322        value: String,
323    }
324
325    impl ProductB {
326        #[allow(dead_code)]
327        fn new(value: &str) -> Self {
328            Self {
329                value: value.to_string(),
330            }
331        }
332    }
333
334    impl TestProduct for ProductB {
335        fn get_value(&self) -> &str {
336            &self.value
337        }
338    }
339
340    impl Default for ProductB {
341        fn default() -> Self {
342            Self {
343                value: "default_b".to_string(),
344            }
345        }
346    }
347
348    // 注册测试工厂
349    register_factory!(dyn TestProduct, "product_a", ProductA);
350    register_factory!(dyn TestProduct, "product_b", ProductB);
351
352    #[test]
353    fn test_factory_registration() {
354        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
355
356        // 测试我们注册的工厂应该存在
357        let result_a = factory.create("product_a", FactoryFallback::NoFallback);
358        assert!(result_a.is_ok(), "product_a factory should exist");
359
360        let result_b = factory.create("product_b", FactoryFallback::NoFallback);
361        assert!(result_b.is_ok(), "product_b factory should exist");
362    }
363
364    #[test]
365    fn test_factory_creation() {
366        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
367
368        // 测试创建 ProductA
369        let result = factory.create("product_a", FactoryFallback::NoFallback);
370        assert!(result.is_ok());
371
372        let product = result.unwrap();
373        assert_eq!(product.get_value(), "default_a");
374
375        // 测试创建 ProductB
376        let result = factory.create("product_b", FactoryFallback::NoFallback);
377        assert!(result.is_ok());
378
379        let product = result.unwrap();
380        assert_eq!(product.get_value(), "default_b");
381    }
382
383    #[test]
384    fn test_factory_error_cases() {
385        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
386
387        // 测试不存在的工厂 ID
388        let result = factory.create("non_existent", FactoryFallback::NoFallback);
389        assert!(result.is_err());
390
391        if let Err(FactoryError::FactoryNotFound(id)) = result {
392            assert_eq!(id, "non_existent");
393        } else {
394            panic!("Expected FactoryNotFound error");
395        }
396
397        // 测试空 ID 无回退
398        let result = factory.create("", FactoryFallback::NoFallback);
399        assert!(result.is_err());
400
401        if let Err(FactoryError::EmptyIdNoFallback) = result {
402            // 正确
403        } else {
404            panic!("Expected EmptyIdNoFallback error");
405        }
406    }
407
408    #[test]
409    fn test_factory_fallback_first() {
410        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
411
412        // 测试 First 回退策略(空 ID)
413        let result = factory.create("", FactoryFallback::First);
414        // 由于 inventory 是全局的,可能在其他测试中注册了工厂
415        // 所以这里可能成功也可能失败,我们只检查行为是否正确
416        match result {
417            Ok(_product) => {
418                // 回退成功
419            }
420            Err(FactoryError::NoFactoriesAvailable) => {
421                // 如果没有工厂可用,这也是有效的
422            }
423            Err(e) => {
424                // 其他错误不应该发生
425                panic!("Unexpected error: {:?}", e);
426            }
427        }
428
429        // 测试 First 回退策略(无效 ID)
430        let result = factory.create("invalid_id", FactoryFallback::First);
431        match result {
432            Ok(_product) => {
433                panic!("Expected FactoryNotFound for invalid ID");
434            }
435            Err(FactoryError::FactoryNotFound(id)) => {
436                assert_eq!(id, "invalid_id");
437            }
438            Err(e) => {
439                // 其他错误不应该发生
440                panic!("Unexpected error: {:?}", e);
441            }
442        }
443    }
444
445    #[test]
446    fn test_factory_fallback_last() {
447        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
448
449        // 测试 Last 回退策略(空 ID)
450        let result = factory.create("", FactoryFallback::Last);
451        // 由于 inventory 是全局的,可能在其他测试中注册了工厂
452        // 所以这里可能成功也可能失败,我们只检查行为是否正确
453        match result {
454            Ok(_product) => {
455                // 回退成功
456            }
457            Err(FactoryError::NoFactoriesAvailable) => {
458                // 如果没有工厂可用,这也是有效的
459            }
460            Err(e) => {
461                // 其他错误不应该发生
462                panic!("Unexpected error: {:?}", e);
463            }
464        }
465
466        // 测试 Last 回退策略(无效 ID)
467        let result = factory.create("invalid_id", FactoryFallback::Last);
468        match result {
469            Ok(_product) => {
470                panic!("Expected FactoryNotFound for invalid ID");
471            }
472            Err(FactoryError::FactoryNotFound(id)) => {
473                assert_eq!(id, "invalid_id");
474            }
475            Err(e) => {
476                // 其他错误不应该发生
477                panic!("Unexpected error: {:?}", e);
478            }
479        }
480    }
481
482    #[test]
483    fn test_factory_no_factories_available() {
484        // 测试没有工厂的情况
485        // 创建一个新的 trait 和工厂注册表,但不注册任何工厂
486        trait EmptyProduct {
487            #[allow(dead_code)]
488            fn dummy(&self);
489        }
490
491        let factory = FactoryRegistry::<dyn EmptyProduct>::simple_factory();
492
493        // 测试空工厂集合
494        let result = factory.create("", FactoryFallback::First);
495        assert!(result.is_err());
496
497        if let Err(FactoryError::NoFactoriesAvailable) = result {
498            // 正确
499        } else {
500            panic!("Expected NoFactoriesAvailable error");
501        }
502
503        let result = factory.create("", FactoryFallback::Last);
504        assert!(result.is_err());
505
506        if let Err(FactoryError::NoFactoriesAvailable) = result {
507            // 正确
508        } else {
509            panic!("Expected NoFactoriesAvailable error");
510        }
511    }
512
513    #[test]
514    fn test_factory_registry_new() {
515        // 测试 FactoryRegistry::new 函数
516        struct TestFactory;
517
518        impl Factory<String> for TestFactory {
519            fn create(&self) -> Box<String> {
520                Box::new("test".to_string())
521            }
522        }
523
524        let factory = &TestFactory as &'static (dyn Factory<String> + Sync);
525        let registry = FactoryRegistry::new("test_id", factory);
526
527        assert_eq!(registry.id, "test_id");
528        assert_eq!(registry.type_id, TypeId::of::<String>());
529    }
530
531    #[test]
532    fn test_factory_error_display() {
533        // 测试错误信息的显示
534        let error = FactoryError::FactoryNotFound("test_id".to_string());
535        assert_eq!(format!("{}", error), "factory with ID 'test_id' not found");
536
537        let error = FactoryError::EmptyIdNoFallback;
538        assert_eq!(format!("{}", error), "empty ID provided without fallback");
539
540        let error = FactoryError::NoFactoriesAvailable;
541        assert_eq!(format!("{}", error), "no factories available");
542    }
543
544    #[test]
545    fn test_factory_fallback_debug() {
546        // 测试 FactoryFallback 的 Debug 实现
547        assert_eq!(format!("{:?}", FactoryFallback::First), "First");
548        assert_eq!(format!("{:?}", FactoryFallback::Last), "Last");
549        assert_eq!(format!("{:?}", FactoryFallback::NoFallback), "NoFallback");
550    }
551
552    #[test]
553    fn test_factory_fallback_eq() {
554        // 测试 FactoryFallback 的相等性
555        assert_eq!(FactoryFallback::First, FactoryFallback::First);
556        assert_eq!(FactoryFallback::Last, FactoryFallback::Last);
557        assert_eq!(FactoryFallback::NoFallback, FactoryFallback::NoFallback);
558        assert_ne!(FactoryFallback::First, FactoryFallback::Last);
559        assert_ne!(FactoryFallback::First, FactoryFallback::NoFallback);
560    }
561
562    #[test]
563    fn test_simple_factory_debug() {
564        // 测试 SimpleFactory 的 Debug 实现
565        // SimpleFactory 没有实现 Debug,所以跳过这个测试
566        // 或者我们可以测试工厂是否正常工作
567        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
568        let result = factory.create("product_a", FactoryFallback::NoFallback);
569        assert!(result.is_ok());
570    }
571}