Skip to main content

trellis_core/
collection.rs

1use crate::collection_diff::{MapDiff, SetDiff};
2use crate::input::downcast_input;
3use crate::{
4    CollectionDiffTrace, CollectionNode, DeriveError, DerivedNode, Graph, InputNode, NodeId,
5};
6use core::{any::Any, marker::PhantomData};
7use std::collections::{BTreeMap, BTreeSet};
8use std::sync::Arc;
9
10type CollectionComputeResult = Result<Box<dyn StoredCollection>, DeriveError>;
11type ComputeFn<C, O> = dyn for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> CollectionComputeResult;
12
/// Zero-sized marker naming a map-shaped collection with keys `K` and values `V`.
///
/// Wrapping `fn() -> (K, V)` (rather than `(K, V)`) in `PhantomData` keeps the
/// marker covariant in `K`/`V` without claiming ownership of them, and makes it
/// `Send + Sync` regardless of the key/value types.
pub(crate) struct MapCollectionShape<K, V>(PhantomData<fn() -> (K, V)>);

/// Zero-sized marker naming a set-shaped collection with elements `K`.
///
/// Uses the same `fn() -> K` phantom trick as [`MapCollectionShape`].
pub(crate) struct SetCollectionShape<K>(PhantomData<fn() -> K>);
15pub(crate) struct CollectionSpec<C, O> {
16    compute: Arc<ComputeFn<C, O>>,
17}
18
19impl<C, O> Clone for CollectionSpec<C, O> {
20    fn clone(&self) -> Self {
21        Self {
22            compute: Arc::clone(&self.compute),
23        }
24    }
25}
26impl<C, O> CollectionSpec<C, O> {
27    pub(crate) fn map<K, V, F>(derive: F) -> Self
28    where
29        K: Clone + Ord + 'static,
30        V: Clone + PartialEq + 'static,
31        F: for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> Result<BTreeMap<K, V>, DeriveError>
32            + 'static,
33    {
34        Self {
35            compute: Arc::new(move |ctx| derive(ctx).map(boxed_map)),
36        }
37    }
38
39    pub(crate) fn set<K, F>(derive: F) -> Self
40    where
41        K: Clone + Ord + 'static,
42        F: for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> Result<BTreeSet<K>, DeriveError>
43            + 'static,
44    {
45        Self {
46            compute: Arc::new(move |ctx| derive(ctx).map(boxed_set)),
47        }
48    }
49
50    pub(crate) fn compute(
51        &self,
52        ctx: &CollectionContext<'_, C, O>,
53    ) -> Result<Box<dyn StoredCollection>, DeriveError> {
54        (self.compute)(ctx)
55    }
56}
57/// Read-only context passed to pure collection node computations.
58pub struct CollectionContext<'graph, C = (), O = ()> {
59    graph: &'graph Graph<C, O>,
60    declared_dependencies: &'graph [NodeId],
61}
62
63impl<'graph, C, O> CollectionContext<'graph, C, O> {
64    pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
65        Self {
66            graph,
67            declared_dependencies,
68        }
69    }
70
71    /// Reads a declared input dependency.
72    pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
73    where
74        T: Clone + PartialEq + 'static,
75    {
76        let node = input.id();
77        self.require_declared(node)?;
78        self.graph
79            .input_values
80            .get(&node)
81            .and_then(|value| downcast_input::<T>(value.as_ref()))
82            .ok_or(DeriveError::MissingValue(node))
83    }
84
85    /// Reads a declared scalar derived dependency.
86    pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
87    where
88        T: Clone + PartialEq + 'static,
89    {
90        let node = derived.id();
91        self.require_declared(node)?;
92        self.graph
93            .derived_values
94            .get(&node)
95            .and_then(|value| downcast_input::<T>(value.as_ref()))
96            .ok_or(DeriveError::MissingValue(node))
97    }
98
99    /// Reads a declared map collection dependency.
100    pub fn map_collection<K, V>(
101        &self,
102        collection: CollectionNode<K, V>,
103    ) -> Result<&'graph BTreeMap<K, V>, DeriveError>
104    where
105        K: Clone + Ord + 'static,
106        V: Clone + PartialEq + 'static,
107    {
108        let node = collection.id();
109        self.require_declared(node)?;
110        self.graph
111            .validate_map_collection_read::<K, V>(node)
112            .map_err(|_| DeriveError::WrongCollectionType(node))?;
113        self.graph
114            .collection_values
115            .get(&node)
116            .and_then(|value| downcast_map::<K, V>(value.as_ref()))
117            .ok_or(DeriveError::MissingValue(node))
118    }
119
120    /// Reads a declared set collection dependency.
121    pub fn set_collection<K>(
122        &self,
123        collection: CollectionNode<K, ()>,
124    ) -> Result<&'graph BTreeSet<K>, DeriveError>
125    where
126        K: Clone + Ord + 'static,
127    {
128        let node = collection.id();
129        self.require_declared(node)?;
130        self.graph
131            .validate_set_collection_read::<K>(node)
132            .map_err(|_| DeriveError::WrongCollectionType(node))?;
133        self.graph
134            .collection_values
135            .get(&node)
136            .and_then(|value| downcast_set::<K>(value.as_ref()))
137            .ok_or(DeriveError::MissingValue(node))
138    }
139
140    fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
141        if self.declared_dependencies.contains(&node) {
142            Ok(())
143        } else {
144            Err(DeriveError::UndeclaredDependency(node))
145        }
146    }
147}
148
149pub(crate) trait StoredCollection: Any {
150    fn clone_box(&self) -> Box<dyn StoredCollection>;
151    fn empty_box(&self) -> Box<dyn StoredCollection>;
152    fn equals(&self, other: &dyn StoredCollection) -> bool;
153    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff>;
154    fn as_any(&self) -> &dyn Any;
155}
156
157impl Clone for Box<dyn StoredCollection> {
158    fn clone(&self) -> Self {
159        self.clone_box()
160    }
161}
162
163pub(crate) trait StoredDiff: Any {
164    fn clone_box(&self) -> Box<dyn StoredDiff>;
165    fn trace(&self, node: NodeId) -> CollectionDiffTrace;
166    fn as_any(&self) -> &dyn Any;
167}
168
169impl Clone for Box<dyn StoredDiff> {
170    fn clone(&self) -> Self {
171        self.clone_box()
172    }
173}
174
/// Concrete storage behind a type-erased map collection.
#[derive(Clone)]
struct MapCollection<K, V> {
    value: BTreeMap<K, V>,
}

/// Concrete storage behind a type-erased set collection.
#[derive(Clone)]
struct SetCollection<K> {
    value: BTreeSet<K>,
}
184
185impl<K, V> StoredCollection for MapCollection<K, V>
186where
187    K: Clone + Ord + 'static,
188    V: Clone + PartialEq + 'static,
189{
190    fn clone_box(&self) -> Box<dyn StoredCollection> {
191        Box::new(self.clone())
192    }
193
194    fn empty_box(&self) -> Box<dyn StoredCollection> {
195        boxed_map(BTreeMap::<K, V>::new())
196    }
197
198    fn equals(&self, other: &dyn StoredCollection) -> bool {
199        other
200            .as_any()
201            .downcast_ref::<Self>()
202            .is_some_and(|other| self.value == other.value)
203    }
204
205    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff> {
206        let next = next
207            .as_any()
208            .downcast_ref::<Self>()
209            .expect("collection type stays stable");
210        Box::new(MapDiff::between(&self.value, &next.value))
211    }
212
213    fn as_any(&self) -> &dyn Any {
214        self
215    }
216}
217
218impl<K> StoredCollection for SetCollection<K>
219where
220    K: Clone + Ord + 'static,
221{
222    fn clone_box(&self) -> Box<dyn StoredCollection> {
223        Box::new(self.clone())
224    }
225
226    fn empty_box(&self) -> Box<dyn StoredCollection> {
227        boxed_set(BTreeSet::<K>::new())
228    }
229
230    fn equals(&self, other: &dyn StoredCollection) -> bool {
231        other
232            .as_any()
233            .downcast_ref::<Self>()
234            .is_some_and(|other| self.value == other.value)
235    }
236
237    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff> {
238        let next = next
239            .as_any()
240            .downcast_ref::<Self>()
241            .expect("collection type stays stable");
242        Box::new(SetDiff::between(&self.value, &next.value))
243    }
244
245    fn as_any(&self) -> &dyn Any {
246        self
247    }
248}
249
250pub(crate) fn boxed_map<K, V>(value: BTreeMap<K, V>) -> Box<dyn StoredCollection>
251where
252    K: Clone + Ord + 'static,
253    V: Clone + PartialEq + 'static,
254{
255    Box::new(MapCollection { value })
256}
257
258pub(crate) fn boxed_set<K>(value: BTreeSet<K>) -> Box<dyn StoredCollection>
259where
260    K: Clone + Ord + 'static,
261{
262    Box::new(SetCollection { value })
263}
264
265pub(crate) fn downcast_map<K, V>(value: &dyn StoredCollection) -> Option<&BTreeMap<K, V>>
266where
267    K: Clone + Ord + 'static,
268    V: Clone + PartialEq + 'static,
269{
270    value
271        .as_any()
272        .downcast_ref::<MapCollection<K, V>>()
273        .map(|collection| &collection.value)
274}
275
276pub(crate) fn downcast_set<K>(value: &dyn StoredCollection) -> Option<&BTreeSet<K>>
277where
278    K: Clone + Ord + 'static,
279{
280    value
281        .as_any()
282        .downcast_ref::<SetCollection<K>>()
283        .map(|collection| &collection.value)
284}
285
286pub(crate) fn downcast_map_diff<K, V>(value: &dyn StoredDiff) -> Option<&MapDiff<K, V>>
287where
288    K: Clone + Ord + 'static,
289    V: Clone + PartialEq + 'static,
290{
291    value.as_any().downcast_ref::<MapDiff<K, V>>()
292}
293
294pub(crate) fn downcast_set_diff<K>(value: &dyn StoredDiff) -> Option<&SetDiff<K>>
295where
296    K: Clone + Ord + 'static,
297{
298    value.as_any().downcast_ref::<SetDiff<K>>()
299}