Skip to main content

trellis_core/
collection.rs

1use crate::collection_diff::{MapDiff, SetDiff};
2use crate::input::downcast_input;
3use crate::{
4    CollectionDiffTrace, CollectionNode, DeriveError, DerivedNode, Graph, InputNode, NodeId,
5};
6use core::{any::Any, marker::PhantomData};
7use std::collections::{BTreeMap, BTreeSet};
8use std::sync::Arc;
9
10type CollectionComputeResult = Result<Box<dyn StoredCollection>, DeriveError>;
11type ComputeFn<C, O> =
12    dyn for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> CollectionComputeResult + Send + Sync;
13
/// Zero-sized type-level marker naming a map collection with keys `K` and values `V`.
///
/// No `K` or `V` is stored. The `fn() -> (K, V)` phantom keeps the marker
/// `Send + Sync` and covariant whatever `K`/`V` are.
pub(crate) struct MapCollectionShape<K, V>(PhantomData<fn() -> (K, V)>);
/// Zero-sized type-level marker naming a set collection with elements `K`.
///
/// Uses the same `fn() -> K` phantom trick as [`MapCollectionShape`].
pub(crate) struct SetCollectionShape<K>(PhantomData<fn() -> K>);
16pub(crate) struct CollectionSpec<C, O> {
17    compute: Arc<ComputeFn<C, O>>,
18}
19
20impl<C, O> Clone for CollectionSpec<C, O> {
21    fn clone(&self) -> Self {
22        Self {
23            compute: Arc::clone(&self.compute),
24        }
25    }
26}
27impl<C, O> CollectionSpec<C, O> {
28    pub(crate) fn map<K, V, F>(derive: F) -> Self
29    where
30        K: Clone + Ord + Send + Sync + 'static,
31        V: Clone + PartialEq + Send + Sync + 'static,
32        F: for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> Result<BTreeMap<K, V>, DeriveError>
33            + Send
34            + Sync
35            + 'static,
36    {
37        Self {
38            compute: Arc::new(move |ctx| derive(ctx).map(boxed_map)),
39        }
40    }
41
42    pub(crate) fn set<K, F>(derive: F) -> Self
43    where
44        K: Clone + Ord + Send + Sync + 'static,
45        F: for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> Result<BTreeSet<K>, DeriveError>
46            + Send
47            + Sync
48            + 'static,
49    {
50        Self {
51            compute: Arc::new(move |ctx| derive(ctx).map(boxed_set)),
52        }
53    }
54
55    pub(crate) fn compute(
56        &self,
57        ctx: &CollectionContext<'_, C, O>,
58    ) -> Result<Box<dyn StoredCollection>, DeriveError> {
59        (self.compute)(ctx)
60    }
61}
62/// Read-only context passed to pure collection node computations.
63pub struct CollectionContext<'graph, C = (), O = ()> {
64    graph: &'graph Graph<C, O>,
65    declared_dependencies: &'graph [NodeId],
66}
67
68impl<'graph, C, O> CollectionContext<'graph, C, O> {
69    pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
70        Self {
71            graph,
72            declared_dependencies,
73        }
74    }
75
76    /// Reads a declared input dependency.
77    pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
78    where
79        T: Clone + PartialEq + Send + Sync + 'static,
80    {
81        let node = input.id();
82        self.require_declared(node)?;
83        self.graph
84            .input_values
85            .get(&node)
86            .and_then(|value| downcast_input::<T>(value.as_ref()))
87            .ok_or(DeriveError::MissingValue(node))
88    }
89
90    /// Reads a declared scalar derived dependency.
91    pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
92    where
93        T: Clone + PartialEq + Send + Sync + 'static,
94    {
95        let node = derived.id();
96        self.require_declared(node)?;
97        self.graph
98            .derived_values
99            .get(&node)
100            .and_then(|value| downcast_input::<T>(value.as_ref()))
101            .ok_or(DeriveError::MissingValue(node))
102    }
103
104    /// Reads a declared map collection dependency.
105    pub fn map_collection<K, V>(
106        &self,
107        collection: CollectionNode<K, V>,
108    ) -> Result<&'graph BTreeMap<K, V>, DeriveError>
109    where
110        K: Clone + Ord + Send + Sync + 'static,
111        V: Clone + PartialEq + Send + Sync + 'static,
112    {
113        let node = collection.id();
114        self.require_declared(node)?;
115        self.graph
116            .validate_map_collection_read::<K, V>(node)
117            .map_err(|_| DeriveError::WrongCollectionType(node))?;
118        self.graph
119            .collection_values
120            .get(&node)
121            .and_then(|value| downcast_map::<K, V>(value.as_ref()))
122            .ok_or(DeriveError::MissingValue(node))
123    }
124
125    /// Reads a declared set collection dependency.
126    pub fn set_collection<K>(
127        &self,
128        collection: CollectionNode<K, ()>,
129    ) -> Result<&'graph BTreeSet<K>, DeriveError>
130    where
131        K: Clone + Ord + Send + Sync + 'static,
132    {
133        let node = collection.id();
134        self.require_declared(node)?;
135        self.graph
136            .validate_set_collection_read::<K>(node)
137            .map_err(|_| DeriveError::WrongCollectionType(node))?;
138        self.graph
139            .collection_values
140            .get(&node)
141            .and_then(|value| downcast_set::<K>(value.as_ref()))
142            .ok_or(DeriveError::MissingValue(node))
143    }
144
145    fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
146        if self.declared_dependencies.contains(&node) {
147            Ok(())
148        } else {
149            Err(DeriveError::UndeclaredDependency(node))
150        }
151    }
152}
153
154pub(crate) trait StoredCollection: Any + Send + Sync {
155    fn clone_box(&self) -> Box<dyn StoredCollection>;
156    fn empty_box(&self) -> Box<dyn StoredCollection>;
157    fn equals(&self, other: &dyn StoredCollection) -> bool;
158    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff>;
159    fn as_any(&self) -> &dyn Any;
160}
161
162impl Clone for Box<dyn StoredCollection> {
163    fn clone(&self) -> Self {
164        self.clone_box()
165    }
166}
167
168pub(crate) trait StoredDiff: Any + Send + Sync {
169    fn clone_box(&self) -> Box<dyn StoredDiff>;
170    fn trace(&self, node: NodeId) -> CollectionDiffTrace;
171    fn as_any(&self) -> &dyn Any;
172}
173
174impl Clone for Box<dyn StoredDiff> {
175    fn clone(&self) -> Self {
176        self.clone_box()
177    }
178}
179
/// Concrete storage for a map collection: a thin wrapper around `BTreeMap<K, V>`.
#[derive(Clone)]
struct MapCollection<K, V> {
    /// The collection's current entries.
    value: BTreeMap<K, V>,
}
184
/// Concrete storage for a set collection: a thin wrapper around `BTreeSet<K>`.
#[derive(Clone)]
struct SetCollection<K> {
    /// The collection's current members.
    value: BTreeSet<K>,
}
189
190impl<K, V> StoredCollection for MapCollection<K, V>
191where
192    K: Clone + Ord + Send + Sync + 'static,
193    V: Clone + PartialEq + Send + Sync + 'static,
194{
195    fn clone_box(&self) -> Box<dyn StoredCollection> {
196        Box::new(self.clone())
197    }
198
199    fn empty_box(&self) -> Box<dyn StoredCollection> {
200        boxed_map(BTreeMap::<K, V>::new())
201    }
202
203    fn equals(&self, other: &dyn StoredCollection) -> bool {
204        other
205            .as_any()
206            .downcast_ref::<Self>()
207            .is_some_and(|other| self.value == other.value)
208    }
209
210    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff> {
211        let next = next
212            .as_any()
213            .downcast_ref::<Self>()
214            .expect("collection type stays stable");
215        Box::new(MapDiff::between(&self.value, &next.value))
216    }
217
218    fn as_any(&self) -> &dyn Any {
219        self
220    }
221}
222
223impl<K> StoredCollection for SetCollection<K>
224where
225    K: Clone + Ord + Send + Sync + 'static,
226{
227    fn clone_box(&self) -> Box<dyn StoredCollection> {
228        Box::new(self.clone())
229    }
230
231    fn empty_box(&self) -> Box<dyn StoredCollection> {
232        boxed_set(BTreeSet::<K>::new())
233    }
234
235    fn equals(&self, other: &dyn StoredCollection) -> bool {
236        other
237            .as_any()
238            .downcast_ref::<Self>()
239            .is_some_and(|other| self.value == other.value)
240    }
241
242    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff> {
243        let next = next
244            .as_any()
245            .downcast_ref::<Self>()
246            .expect("collection type stays stable");
247        Box::new(SetDiff::between(&self.value, &next.value))
248    }
249
250    fn as_any(&self) -> &dyn Any {
251        self
252    }
253}
254
255pub(crate) fn boxed_map<K, V>(value: BTreeMap<K, V>) -> Box<dyn StoredCollection>
256where
257    K: Clone + Ord + Send + Sync + 'static,
258    V: Clone + PartialEq + Send + Sync + 'static,
259{
260    Box::new(MapCollection { value })
261}
262
263pub(crate) fn boxed_set<K>(value: BTreeSet<K>) -> Box<dyn StoredCollection>
264where
265    K: Clone + Ord + Send + Sync + 'static,
266{
267    Box::new(SetCollection { value })
268}
269
270pub(crate) fn downcast_map<K, V>(value: &dyn StoredCollection) -> Option<&BTreeMap<K, V>>
271where
272    K: Clone + Ord + Send + Sync + 'static,
273    V: Clone + PartialEq + Send + Sync + 'static,
274{
275    value
276        .as_any()
277        .downcast_ref::<MapCollection<K, V>>()
278        .map(|collection| &collection.value)
279}
280
281pub(crate) fn downcast_set<K>(value: &dyn StoredCollection) -> Option<&BTreeSet<K>>
282where
283    K: Clone + Ord + Send + Sync + 'static,
284{
285    value
286        .as_any()
287        .downcast_ref::<SetCollection<K>>()
288        .map(|collection| &collection.value)
289}
290
291pub(crate) fn downcast_map_diff<K, V>(value: &dyn StoredDiff) -> Option<&MapDiff<K, V>>
292where
293    K: Clone + Ord + Send + Sync + 'static,
294    V: Clone + PartialEq + Send + Sync + 'static,
295{
296    value.as_any().downcast_ref::<MapDiff<K, V>>()
297}
298
299pub(crate) fn downcast_set_diff<K>(value: &dyn StoredDiff) -> Option<&SetDiff<K>>
300where
301    K: Clone + Ord + Send + Sync + 'static,
302{
303    value.as_any().downcast_ref::<SetDiff<K>>()
304}