1use async_trait::async_trait;
2use std::any::Any;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6#[async_trait]
7pub trait Supplier: Send + Sync {
8 type Input: Send + 'static;
9 type Output: Send + 'static;
10
11 async fn supply(&self, input: Self::Input, scope: Arc<Scope>) -> Self::Output;
12}
13
14#[async_trait]
15pub trait ErasedSupplier: Send + Sync {
16 async fn supply_erased(
17 &self,
18 input: Box<dyn Any + Send>,
19 scope: Arc<Scope>,
20 ) -> Box<dyn Any + Send>;
21}
22
23#[async_trait]
25impl<T> ErasedSupplier for T
26where
27 T: Supplier + Send + Sync,
28{
29 async fn supply_erased(
30 &self,
31 input: Box<dyn Any + Send>,
32 scope: Arc<Scope>,
33 ) -> Box<dyn Any + Send> {
34 let input = *input.downcast::<T::Input>().expect("Input type mismatch");
35 let out = self.supply(input, scope).await;
36 Box::new(out)
37 }
38}
39
40pub type SupplierRegistry = HashMap<String, Arc<dyn ErasedSupplier>>;
41
42pub struct Scope {
43 pub registry: Arc<SupplierRegistry>,
44}
45
46pub struct Demand {
47 pub type_: String,
48 pub override_suppliers: Option<SupplierRegistry>,
49}
50
51impl Scope {
52 pub async fn demand<T: Send + 'static>(&self, demand: Demand, input: Box<dyn Any + Send>) -> T {
53 let registry = if let Some(overrides) = &demand.override_suppliers {
54 let mut new = (*self.registry).clone();
55 for (k, v) in overrides.iter() {
56 new.insert(k.clone(), v.clone());
57 }
58 Arc::new(new)
59 } else {
60 self.registry.clone()
61 };
62 let new_scope = Arc::new(Scope { registry });
63
64 let supplier = new_scope
65 .registry
66 .get(&demand.type_)
67 .expect("Supplier not found")
68 .clone();
69
70 let result = supplier.supply_erased(input, new_scope).await;
71 *result.downcast::<T>().expect("Output type mismatch")
72 }
73}