supply_demand/
lib.rs

1use async_trait::async_trait;
2use std::any::Any;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6#[async_trait]
7pub trait Supplier: Send + Sync {
8    type Input: Send + 'static;
9    type Output: Send + 'static;
10
11    async fn supply(&self, input: Self::Input, scope: Arc<Scope>) -> Self::Output;
12}
13
14#[async_trait]
15pub trait ErasedSupplier: Send + Sync {
16    async fn supply_erased(
17        &self,
18        input: Box<dyn Any + Send>,
19        scope: Arc<Scope>,
20    ) -> Box<dyn Any + Send>;
21}
22
23// Blanket impl: any strongly-typed supplier is an ErasedSupplier
24#[async_trait]
25impl<T> ErasedSupplier for T
26where
27    T: Supplier + Send + Sync,
28{
29    async fn supply_erased(
30        &self,
31        input: Box<dyn Any + Send>,
32        scope: Arc<Scope>,
33    ) -> Box<dyn Any + Send> {
34        let input = *input.downcast::<T::Input>().expect("Input type mismatch");
35        let out = self.supply(input, scope).await;
36        Box::new(out)
37    }
38}
39
40pub type SupplierRegistry = HashMap<String, Arc<dyn ErasedSupplier>>;
41
42pub struct Scope {
43    pub registry: Arc<SupplierRegistry>,
44}
45
46pub struct Demand {
47    pub type_: String,
48    pub override_suppliers: Option<SupplierRegistry>,
49}
50
51impl Scope {
52    pub async fn demand<T: Send + 'static>(&self, demand: Demand, input: Box<dyn Any + Send>) -> T {
53        let registry = if let Some(overrides) = &demand.override_suppliers {
54            let mut new = (*self.registry).clone();
55            for (k, v) in overrides.iter() {
56                new.insert(k.clone(), v.clone());
57            }
58            Arc::new(new)
59        } else {
60            self.registry.clone()
61        };
62        let new_scope = Arc::new(Scope { registry });
63
64        let supplier = new_scope
65            .registry
66            .get(&demand.type_)
67            .expect("Supplier not found")
68            .clone();
69
70        let result = supplier.supply_erased(input, new_scope).await;
71        *result.downcast::<T>().expect("Output type mismatch")
72    }
73}