Skip to main content

trellis_core/
collection.rs

1pub(crate) use crate::collection_storage::{
2    StoredCollection, StoredDiff, boxed_map, boxed_set, downcast_map, downcast_map_diff,
3    downcast_set, downcast_set_diff,
4};
5use crate::input::downcast_input;
6use crate::{CollectionNode, DeriveError, DerivedNode, Graph, InputNode, NodeId};
7use core::marker::PhantomData;
8use std::collections::{BTreeMap, BTreeSet};
9use std::sync::Arc;
10
/// Result of running a collection derivation: the produced collection, type-erased
/// behind `StoredCollection`, or the `DeriveError` that stopped the derivation.
type CollectionComputeResult = Result<Box<dyn StoredCollection>, DeriveError>;
/// Type-erased derivation closure shared by every collection spec.
///
/// Higher-ranked over `'ctx`, so a single closure can be called with a context of
/// any lifetime. `Send + Sync`, so the `Arc` holding it is itself `Send + Sync`.
type ComputeFn<C> =
    dyn for<'ctx> Fn(&CollectionContext<'ctx, C>) -> CollectionComputeResult + Send + Sync;

/// Zero-sized type-level marker for a map-shaped collection with keys `K` and values `V`.
///
/// `PhantomData<fn() -> (K, V)>` keeps the marker covariant in `K`/`V` and `Send + Sync`
/// regardless of `K`/`V`, without implying that it owns a `K` or `V`.
/// NOTE(review): not referenced in this file; presumably used as a tag elsewhere — confirm.
pub(crate) struct MapCollectionShape<K, V>(PhantomData<fn() -> (K, V)>);
/// Zero-sized type-level marker for a set-shaped collection of keys `K`.
///
/// Uses `PhantomData<fn() -> K>` for the same variance/auto-trait reasons as the map marker.
pub(crate) struct SetCollectionShape<K>(PhantomData<fn() -> K>);
/// A collection node's derivation, stored as a shared, type-erased closure.
///
/// Keeping the closure behind an `Arc` makes cloning a spec a reference-count bump.
pub(crate) struct CollectionSpec<C> {
    // Closure invoked by `CollectionSpec::compute`; built by `map` or `set`.
    compute: Arc<ComputeFn<C>>,
}
20
21impl<C> Clone for CollectionSpec<C> {
22    fn clone(&self) -> Self {
23        Self {
24            compute: Arc::clone(&self.compute),
25        }
26    }
27}
28impl<C> CollectionSpec<C> {
29    pub(crate) fn map<K, V, F>(derive: F) -> Self
30    where
31        K: Clone + Ord + Send + Sync + 'static,
32        V: Clone + PartialEq + Send + Sync + 'static,
33        F: for<'ctx> Fn(&CollectionContext<'ctx, C>) -> Result<BTreeMap<K, V>, DeriveError>
34            + Send
35            + Sync
36            + 'static,
37    {
38        Self {
39            compute: Arc::new(move |ctx| derive(ctx).map(boxed_map)),
40        }
41    }
42
43    pub(crate) fn set<K, F>(derive: F) -> Self
44    where
45        K: Clone + Ord + Send + Sync + 'static,
46        F: for<'ctx> Fn(&CollectionContext<'ctx, C>) -> Result<BTreeSet<K>, DeriveError>
47            + Send
48            + Sync
49            + 'static,
50    {
51        Self {
52            compute: Arc::new(move |ctx| derive(ctx).map(boxed_set)),
53        }
54    }
55
56    pub(crate) fn compute(
57        &self,
58        ctx: &CollectionContext<'_, C>,
59    ) -> Result<Box<dyn StoredCollection>, DeriveError> {
60        (self.compute)(ctx)
61    }
62}
63/// Read-only context passed to pure collection node computations.
64pub struct CollectionContext<'graph, C = ()> {
65    graph: &'graph Graph<C>,
66    declared_dependencies: &'graph [NodeId],
67}
68
69impl<'graph, C> CollectionContext<'graph, C> {
70    pub(crate) fn new(graph: &'graph Graph<C>, declared_dependencies: &'graph [NodeId]) -> Self {
71        Self {
72            graph,
73            declared_dependencies,
74        }
75    }
76
77    /// Reads a declared input dependency.
78    pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
79    where
80        T: Clone + PartialEq + Send + Sync + 'static,
81    {
82        let node = input.id();
83        self.require_declared(node)?;
84        self.graph
85            .input_values
86            .get(&node)
87            .and_then(|value| downcast_input::<T>(value.as_ref()))
88            .ok_or(DeriveError::MissingValue(node))
89    }
90
91    /// Reads a declared scalar derived dependency.
92    pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
93    where
94        T: Clone + PartialEq + Send + Sync + 'static,
95    {
96        let node = derived.id();
97        self.require_declared(node)?;
98        self.graph
99            .derived_values
100            .get(&node)
101            .and_then(|value| downcast_input::<T>(value.as_ref()))
102            .ok_or(DeriveError::MissingValue(node))
103    }
104
105    /// Reads a declared map collection dependency.
106    pub fn map_collection<K, V>(
107        &self,
108        collection: CollectionNode<K, V>,
109    ) -> Result<&'graph BTreeMap<K, V>, DeriveError>
110    where
111        K: Clone + Ord + Send + Sync + 'static,
112        V: Clone + PartialEq + Send + Sync + 'static,
113    {
114        let node = collection.id();
115        self.require_declared(node)?;
116        self.graph
117            .validate_map_collection_read::<K, V>(node)
118            .map_err(|_| DeriveError::WrongCollectionType(node))?;
119        self.graph
120            .collection_values
121            .get(&node)
122            .and_then(|value| downcast_map::<K, V>(value.as_ref()))
123            .ok_or(DeriveError::MissingValue(node))
124    }
125
126    /// Reads a declared set collection dependency.
127    pub fn set_collection<K>(
128        &self,
129        collection: CollectionNode<K, ()>,
130    ) -> Result<&'graph BTreeSet<K>, DeriveError>
131    where
132        K: Clone + Ord + Send + Sync + 'static,
133    {
134        let node = collection.id();
135        self.require_declared(node)?;
136        self.graph
137            .validate_set_collection_read::<K>(node)
138            .map_err(|_| DeriveError::WrongCollectionType(node))?;
139        self.graph
140            .collection_values
141            .get(&node)
142            .and_then(|value| downcast_set::<K>(value.as_ref()))
143            .ok_or(DeriveError::MissingValue(node))
144    }
145
146    fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
147        if self.declared_dependencies.contains(&node) {
148            Ok(())
149        } else {
150            Err(DeriveError::UndeclaredDependency(node))
151        }
152    }
153}