springtime_di/
factory.rs

1//! Core functionality for creating [Component](crate::component::Component) instances.
2
3use crate::component_registry::conditional::SimpleContextFactory;
4use crate::component_registry::{
5    ComponentDefinition, ComponentDefinitionRegistry, ComponentDefinitionRegistryError,
6    StaticComponentDefinitionRegistry,
7};
8use crate::instance_provider::{
9    CastFunction, ComponentInstanceAnyPtr, ComponentInstanceProvider,
10    ComponentInstanceProviderError,
11};
12use crate::scope::{
13    PrototypeScopeFactory, ScopeFactory, ScopePtr, SingletonScopeFactory, PROTOTYPE, SINGLETON,
14};
15#[cfg(feature = "async")]
16use futures::future::BoxFuture;
17#[cfg(feature = "async")]
18use futures::FutureExt;
19#[cfg(not(feature = "async"))]
20use itertools::Itertools;
21use rustc_hash::{FxHashMap, FxHashSet};
22use std::any::TypeId;
23use tracing::debug;
24
25#[cfg(not(feature = "threadsafe"))]
26pub type ComponentDefinitionRegistryPtr = Box<dyn ComponentDefinitionRegistry>;
27#[cfg(feature = "threadsafe")]
28pub type ComponentDefinitionRegistryPtr = Box<dyn ComponentDefinitionRegistry + Send + Sync>;
29
30#[cfg(not(feature = "threadsafe"))]
31pub type ScopeFactoryPtr = Box<dyn ScopeFactory>;
32#[cfg(feature = "threadsafe")]
33pub type ScopeFactoryPtr = Box<dyn ScopeFactory + Send + Sync>;
34
35pub type ScopeFactoryRegistry = FxHashMap<String, ScopeFactoryPtr>;
36
37/// Builder for [ComponentFactory] with sensible defaults, for easy construction.
38pub struct ComponentFactoryBuilder {
39    definition_registry: ComponentDefinitionRegistryPtr,
40    scope_factories: ScopeFactoryRegistry,
41}
42
43impl ComponentFactoryBuilder {
44    /// Creates a new builder with a default configuration.
45    pub fn new() -> Result<Self, ComponentDefinitionRegistryError> {
46        Ok(Self {
47            definition_registry: Box::new(StaticComponentDefinitionRegistry::new(
48                true,
49                &SimpleContextFactory,
50            )?),
51            scope_factories: [
52                (
53                    SINGLETON.to_string(),
54                    Box::<SingletonScopeFactory>::default() as ScopeFactoryPtr,
55                ),
56                (
57                    PROTOTYPE.to_string(),
58                    Box::<PrototypeScopeFactory>::default() as ScopeFactoryPtr,
59                ),
60            ]
61            .into_iter()
62            .collect(),
63        })
64    }
65
66    /// Sets new [ComponentDefinitionRegistry].
67    pub fn with_definition_registry(
68        mut self,
69        definition_registry: ComponentDefinitionRegistryPtr,
70    ) -> Self {
71        self.definition_registry = definition_registry;
72        self
73    }
74
75    /// Sets new scope factories.
76    pub fn with_scope_factories(mut self, scope_factories: ScopeFactoryRegistry) -> Self {
77        self.scope_factories = scope_factories;
78        self
79    }
80
81    /// Adds a new scope factory.
82    pub fn with_scope_factory<T: ToString>(mut self, name: T, factory: ScopeFactoryPtr) -> Self {
83        self.scope_factories.insert(name.to_string(), factory);
84        self
85    }
86
87    /// Builds resulting [ComponentFactory].
88    pub fn build(self) -> ComponentFactory {
89        ComponentFactory::new(self.definition_registry, self.scope_factories)
90    }
91}
92
93/// Generic factory for [Component](crate::component::Component) instances. Uses definitions from
94/// the [ComponentDefinitionRegistry] and [scopes](crate::scope) to create and store instances for
95/// reuse.
96pub struct ComponentFactory {
97    definition_registry: ComponentDefinitionRegistryPtr,
98    scope_factories: FxHashMap<String, ScopeFactoryPtr>,
99    scopes: FxHashMap<String, ScopePtr>,
100    types_under_construction: FxHashSet<TypeId>,
101}
102
103impl ComponentFactory {
104    /// Creates a new factory with given registry and scope factories. The factory map should
105    /// include built-in [SINGLETON] and [PROTOTYPE] for maximum compatibility with components,
106    /// since they are usually the most popular. This is not a hard requirement, but care needs to
107    /// be taken to ensue no component uses them.
108    pub fn new(
109        definition_registry: ComponentDefinitionRegistryPtr,
110        scope_factories: FxHashMap<String, ScopeFactoryPtr>,
111    ) -> Self {
112        Self {
113            definition_registry,
114            scope_factories,
115            scopes: Default::default(),
116            types_under_construction: Default::default(),
117        }
118    }
119
120    #[cfg(feature = "async")]
121    async fn call_constructor(
122        &mut self,
123        definition: &ComponentDefinition,
124    ) -> Result<ComponentInstanceAnyPtr, ComponentInstanceProviderError> {
125        self.types_under_construction
126            .insert(definition.resolved_type_id);
127        let instance = (definition.constructor)(self).await;
128        self.types_under_construction
129            .remove(&definition.resolved_type_id);
130
131        instance
132    }
133
134    #[cfg(not(feature = "async"))]
135    fn call_constructor(
136        &mut self,
137        definition: &ComponentDefinition,
138    ) -> Result<ComponentInstanceAnyPtr, ComponentInstanceProviderError> {
139        self.types_under_construction
140            .insert(definition.resolved_type_id);
141        let instance = (definition.constructor)(self);
142        self.types_under_construction
143            .remove(&definition.resolved_type_id);
144
145        instance
146    }
147
148    fn check_scope_instance(
149        &mut self,
150        definition: &ComponentDefinition,
151    ) -> Result<Option<(ComponentInstanceAnyPtr, CastFunction)>, ComponentInstanceProviderError>
152    {
153        if self
154            .types_under_construction
155            .contains(&definition.resolved_type_id)
156        {
157            return Err(ComponentInstanceProviderError::DependencyCycle {
158                type_id: definition.resolved_type_id,
159                type_name: None,
160            });
161        }
162
163        let scope = {
164            if let Some(scope) = self.scopes.get(&definition.scope) {
165                scope
166            } else {
167                let factory = self.scope_factories.get(&definition.scope).ok_or_else(|| {
168                    ComponentInstanceProviderError::UnrecognizedScope(definition.scope.to_string())
169                })?;
170
171                self.scopes
172                    .entry(definition.scope.clone())
173                    .or_insert(factory.create_scope())
174            }
175        };
176
177        Ok(scope
178            .instance(definition)
179            .map(|instance| (instance, definition.cast)))
180    }
181
182    fn store_instance_in_scope(
183        &mut self,
184        definition: &ComponentDefinition,
185        instance: ComponentInstanceAnyPtr,
186    ) -> Result<(), ComponentInstanceProviderError> {
187        let scope = self.scopes.get_mut(&definition.scope).ok_or_else(|| {
188            ComponentInstanceProviderError::UnrecognizedScope(definition.scope.to_string())
189        })?;
190
191        scope.store_instance(definition, instance);
192
193        Ok(())
194    }
195
196    #[cfg(feature = "async")]
197    async fn create_instance(
198        &mut self,
199        definition: &ComponentDefinition,
200    ) -> Result<(ComponentInstanceAnyPtr, CastFunction), ComponentInstanceProviderError> {
201        if let Some(instance) = self.check_scope_instance(definition)? {
202            return Ok(instance);
203        }
204
205        debug!(
206            resolved_type_name = definition.resolved_type_name,
207            "Creating new component instance."
208        );
209
210        let instance = self.call_constructor(definition).await?;
211
212        self.store_instance_in_scope(definition, instance.clone())?;
213        Ok((instance, definition.cast))
214    }
215
216    #[cfg(not(feature = "async"))]
217    fn create_instance(
218        &mut self,
219        definition: &ComponentDefinition,
220    ) -> Result<(ComponentInstanceAnyPtr, CastFunction), ComponentInstanceProviderError> {
221        if let Some(instance) = self.check_scope_instance(definition)? {
222            return Ok(instance);
223        }
224
225        debug!(
226            resolved_type_name = definition.resolved_type_name,
227            "Creating new component instance."
228        );
229
230        let instance = self.call_constructor(definition)?;
231
232        self.store_instance_in_scope(definition, instance.clone())?;
233        Ok((instance, definition.cast))
234    }
235}
236
237impl ComponentInstanceProvider for ComponentFactory {
238    #[cfg(feature = "async")]
239    fn primary_instance(
240        &mut self,
241        type_id: TypeId,
242    ) -> BoxFuture<
243        '_,
244        Result<(ComponentInstanceAnyPtr, CastFunction), ComponentInstanceProviderError>,
245    > {
246        async move {
247            let definition = self.definition_registry.primary_component(type_id).ok_or(
248                ComponentInstanceProviderError::NoPrimaryInstance {
249                    type_id,
250                    type_name: None,
251                },
252            )?;
253
254            self.create_instance(&definition).await
255        }
256        .boxed()
257    }
258
259    #[cfg(not(feature = "async"))]
260    fn primary_instance(
261        &mut self,
262        type_id: TypeId,
263    ) -> Result<(ComponentInstanceAnyPtr, CastFunction), ComponentInstanceProviderError> {
264        let definition = self.definition_registry.primary_component(type_id).ok_or(
265            ComponentInstanceProviderError::NoPrimaryInstance {
266                type_id,
267                type_name: None,
268            },
269        )?;
270
271        self.create_instance(&definition)
272    }
273
274    #[cfg(feature = "async")]
275    fn instances(
276        &mut self,
277        type_id: TypeId,
278    ) -> BoxFuture<
279        '_,
280        Result<Vec<(ComponentInstanceAnyPtr, CastFunction)>, ComponentInstanceProviderError>,
281    > {
282        async move {
283            let definitions = self.definition_registry.components_by_type(type_id);
284
285            let mut result = Vec::with_capacity(definitions.len());
286            for definition in &definitions {
287                result.push(self.create_instance(definition).await?);
288            }
289
290            Ok(result)
291        }
292        .boxed()
293    }
294
295    #[cfg(not(feature = "async"))]
296    fn instances(
297        &mut self,
298        type_id: TypeId,
299    ) -> Result<Vec<(ComponentInstanceAnyPtr, CastFunction)>, ComponentInstanceProviderError> {
300        self.definition_registry
301            .components_by_type(type_id)
302            .iter()
303            .map(|definition| self.create_instance(definition))
304            .try_collect()
305    }
306
307    #[cfg(feature = "async")]
308    fn instance_by_name(
309        &mut self,
310        name: &str,
311        type_id: TypeId,
312    ) -> BoxFuture<
313        '_,
314        Result<(ComponentInstanceAnyPtr, CastFunction), ComponentInstanceProviderError>,
315    > {
316        let name = name.to_string();
317        async move {
318            let definition = self
319                .definition_registry
320                .component_by_name(&name, type_id)
321                .ok_or_else(|| ComponentInstanceProviderError::NoNamedInstance(name.to_string()))?;
322
323            self.create_instance(&definition).await
324        }
325        .boxed()
326    }
327
328    #[cfg(not(feature = "async"))]
329    fn instance_by_name(
330        &mut self,
331        name: &str,
332        type_id: TypeId,
333    ) -> Result<(ComponentInstanceAnyPtr, CastFunction), ComponentInstanceProviderError> {
334        let definition = self
335            .definition_registry
336            .component_by_name(name, type_id)
337            .ok_or_else(|| ComponentInstanceProviderError::NoNamedInstance(name.to_string()))?;
338
339        self.create_instance(&definition)
340    }
341}
342
//noinspection DuplicatedCode
#[cfg(test)]
mod tests {
    #[cfg(not(feature = "async"))]
    mod sync {
        use crate::component_registry::{
            ComponentDefinition, ComponentDefinitionRegistry, MockComponentDefinitionRegistry,
        };
        use crate::factory::{ComponentDefinitionRegistryPtr, ComponentFactory, ScopeFactoryPtr};
        use crate::instance_provider::{
            ComponentInstanceAnyPtr, ComponentInstanceProvider, ComponentInstanceProviderError,
            ComponentInstancePtr,
        };
        use crate::scope::{
            MockScope, MockScopeFactory, PrototypeScopeFactory, ScopePtr, PROTOTYPE, SINGLETON,
        };
        use mockall::predicate::*;
        use std::any::{type_name, Any, TypeId};

        /// Cast function that never succeeds; the tests only inspect untyped instances.
        fn cast(
            instance: ComponentInstanceAnyPtr,
        ) -> Result<Box<dyn Any>, ComponentInstanceAnyPtr> {
            Err(instance)
        }

        /// Constructor that always succeeds with a dummy instance.
        fn constructor(
            _instance_provider: &mut dyn ComponentInstanceProvider,
        ) -> Result<ComponentInstanceAnyPtr, ComponentInstanceProviderError> {
            Ok(ComponentInstancePtr::new(0) as ComponentInstanceAnyPtr)
        }

        /// Constructor that always fails, to check that errors are forwarded unchanged.
        fn error_constructor(
            _instance_provider: &mut dyn ComponentInstanceProvider,
        ) -> Result<ComponentInstanceAnyPtr, ComponentInstanceProviderError> {
            Err(ComponentInstanceProviderError::NoPrimaryInstance {
                type_id: TypeId::of::<i8>(),
                type_name: None,
            })
        }

        /// Constructor that requests the primary `i8` instance again. It forms a cycle when it
        /// is itself the `i8` constructor.
        fn recursive_constructor(
            instance_provider: &mut dyn ComponentInstanceProvider,
        ) -> Result<ComponentInstanceAnyPtr, ComponentInstanceProviderError> {
            instance_provider
                .primary_instance(TypeId::of::<i8>())
                .map(|(instance, _)| instance)
        }

        /// Creates a named, prototype-scoped `i8` definition with a succeeding constructor, along
        /// with the type id it resolves to. Tests override individual fields from it.
        fn create_definition() -> (ComponentDefinition, TypeId) {
            (
                ComponentDefinition {
                    names: ["name".to_string()].into_iter().collect(),
                    is_primary: false,
                    scope: PROTOTYPE.to_string(),
                    resolved_type_id: TypeId::of::<i8>(),
                    resolved_type_name: type_name::<i8>().to_string(),
                    constructor,
                    cast,
                },
                TypeId::of::<i8>(),
            )
        }

        /// Creates a registry mock that returns `definition` from `primary_component(id)` and
        /// expects exactly `times` such calls.
        fn create_primary_registry(
            id: TypeId,
            definition: Option<ComponentDefinition>,
            times: usize,
        ) -> MockComponentDefinitionRegistry {
            let mut registry = MockComponentDefinitionRegistry::new();
            registry
                .expect_primary_component()
                .with(eq(id))
                .times(times)
                .return_const(definition);
            registry
        }

        /// Creates a factory backed by `definition_registry`, with only the prototype scope
        /// available.
        fn create_factory<T: ComponentDefinitionRegistry + Send + Sync + 'static>(
            definition_registry: T,
        ) -> ComponentFactory {
            ComponentFactory::new(
                Box::new(definition_registry) as ComponentDefinitionRegistryPtr,
                [(
                    PROTOTYPE.to_string(),
                    Box::<PrototypeScopeFactory>::default() as ScopeFactoryPtr,
                )]
                .into_iter()
                .collect(),
            )
        }

        #[test]
        fn should_return_primary_instance() {
            let (definition, id) = create_definition();

            let mut factory = create_factory(create_primary_registry(id, Some(definition), 1));
            assert!(factory.primary_instance(id).is_ok());
        }

        #[test]
        fn should_detect_primary_instance_loops() {
            let (definition, id) = create_definition();
            let definition = ComponentDefinition {
                names: Default::default(),
                constructor: recursive_constructor,
                ..definition
            };

            // The second lookup happens re-entrantly from inside the recursive constructor.
            let mut factory = create_factory(create_primary_registry(id, Some(definition), 2));
            assert!(matches!(
                factory.primary_instance(id).unwrap_err(),
                ComponentInstanceProviderError::DependencyCycle { type_id, .. } if type_id == TypeId::of::<i8>()
            ));
        }

        #[test]
        fn should_not_return_missing_primary_instance() {
            let id = TypeId::of::<i8>();

            let mut factory = create_factory(create_primary_registry(id, None, 1));
            assert!(matches!(
                factory.primary_instance(id).unwrap_err(),
                ComponentInstanceProviderError::NoPrimaryInstance { type_id, .. } if type_id == TypeId::of::<i8>()
            ));
        }

        #[test]
        fn should_recognize_primary_instance_missing_scope() {
            let (definition, id) = create_definition();
            let definition = ComponentDefinition {
                names: Default::default(),
                scope: SINGLETON.to_string(),
                ..definition
            };

            // `create_factory` only registers the prototype scope.
            let mut factory = create_factory(create_primary_registry(id, Some(definition), 1));
            assert!(matches!(
                factory.primary_instance(id).unwrap_err(),
                ComponentInstanceProviderError::UnrecognizedScope(scope) if scope == SINGLETON
            ));
        }

        #[test]
        fn should_forward_primary_instance_constructor_error() {
            let (definition, id) = create_definition();
            let definition = ComponentDefinition {
                names: Default::default(),
                constructor: error_constructor,
                ..definition
            };

            let mut factory = create_factory(create_primary_registry(id, Some(definition), 1));
            assert!(matches!(
                factory.primary_instance(id).unwrap_err(),
                ComponentInstanceProviderError::NoPrimaryInstance { .. }
            ));
        }

        #[test]
        fn should_store_primary_instance_in_scope() {
            let (definition, id) = create_definition();
            let registry = create_primary_registry(id, Some(definition), 1);

            let mut scope_factory = MockScopeFactory::new();
            scope_factory.expect_create_scope().returning(|| {
                let mut scope = MockScope::new();
                scope.expect_store_instance().times(1).return_const(());
                scope.expect_instance().return_const(None);

                Box::new(scope) as ScopePtr
            });

            let mut factory = ComponentFactory::new(
                Box::new(registry) as ComponentDefinitionRegistryPtr,
                [(
                    PROTOTYPE.to_string(),
                    Box::new(scope_factory) as ScopeFactoryPtr,
                )]
                .into_iter()
                .collect(),
            );

            factory.primary_instance(id).unwrap();
        }

        #[test]
        fn should_return_all_instances() {
            let (definition, id) = create_definition();

            let mut registry = MockComponentDefinitionRegistry::new();
            registry
                .expect_components_by_type()
                .with(eq(id))
                .times(1)
                .return_const(vec![definition.clone(), definition]);

            let mut factory = create_factory(registry);
            assert_eq!(factory.instances(id).unwrap().len(), 2);
        }

        #[test]
        fn should_return_no_instances_without_definitions() {
            let id = TypeId::of::<i8>();

            let mut registry = MockComponentDefinitionRegistry::new();
            registry
                .expect_components_by_type()
                .with(eq(id))
                .times(1)
                .return_const(Vec::<ComponentDefinition>::new());

            let mut factory = create_factory(registry);
            assert!(factory.instances(id).unwrap().is_empty());
        }

        #[test]
        fn should_return_instance_by_name() {
            let (definition, id) = create_definition();

            let mut registry = MockComponentDefinitionRegistry::new();
            registry
                .expect_component_by_name()
                .with(eq("name"), eq(id))
                .times(1)
                .return_const(Some(definition));

            let mut factory = create_factory(registry);
            assert!(factory.instance_by_name("name", id).is_ok());
        }

        #[test]
        fn should_not_return_missing_named_instance() {
            let id = TypeId::of::<i8>();

            let mut registry = MockComponentDefinitionRegistry::new();
            registry
                .expect_component_by_name()
                .with(eq("name"), eq(id))
                .times(1)
                .return_const(None);

            let mut factory = create_factory(registry);
            assert!(matches!(
                factory.instance_by_name("name", id).unwrap_err(),
                ComponentInstanceProviderError::NoNamedInstance(name) if name == "name"
            ));
        }
    }
}