Skip to main content

trellis_core/
collection.rs

1use crate::collection_diff::{MapDiff, SetDiff};
2use crate::input::downcast_input;
3use crate::{CollectionNode, DeriveError, DerivedNode, Graph, InputNode, NodeId};
4use core::any::Any;
5use core::marker::PhantomData;
6use std::collections::{BTreeMap, BTreeSet};
7use std::sync::Arc;
8
9type ComputeFn<C, O> = dyn for<'ctx> Fn(
10    &CollectionContext<'ctx, C, O>,
11) -> Result<Box<dyn StoredCollection>, DeriveError>;
12
13pub(crate) struct MapCollectionShape<K, V>(PhantomData<fn() -> (K, V)>);
14pub(crate) struct SetCollectionShape<K>(PhantomData<fn() -> K>);
15pub(crate) struct CollectionSpec<C, O> {
16    compute: Arc<ComputeFn<C, O>>,
17}
18
19impl<C, O> Clone for CollectionSpec<C, O> {
20    fn clone(&self) -> Self {
21        Self {
22            compute: Arc::clone(&self.compute),
23        }
24    }
25}
26impl<C, O> CollectionSpec<C, O> {
27    pub(crate) fn map<K, V, F>(derive: F) -> Self
28    where
29        K: Clone + Ord + 'static,
30        V: Clone + PartialEq + 'static,
31        F: for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> Result<BTreeMap<K, V>, DeriveError>
32            + 'static,
33    {
34        Self {
35            compute: Arc::new(move |ctx| derive(ctx).map(boxed_map)),
36        }
37    }
38
39    pub(crate) fn set<K, F>(derive: F) -> Self
40    where
41        K: Clone + Ord + 'static,
42        F: for<'ctx> Fn(&CollectionContext<'ctx, C, O>) -> Result<BTreeSet<K>, DeriveError>
43            + 'static,
44    {
45        Self {
46            compute: Arc::new(move |ctx| derive(ctx).map(boxed_set)),
47        }
48    }
49
50    pub(crate) fn compute(
51        &self,
52        ctx: &CollectionContext<'_, C, O>,
53    ) -> Result<Box<dyn StoredCollection>, DeriveError> {
54        (self.compute)(ctx)
55    }
56}
57/// Read-only context passed to pure collection node computations.
58pub struct CollectionContext<'graph, C = (), O = ()> {
59    graph: &'graph Graph<C, O>,
60    declared_dependencies: &'graph [NodeId],
61}
62
63impl<'graph, C, O> CollectionContext<'graph, C, O> {
64    pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
65        Self {
66            graph,
67            declared_dependencies,
68        }
69    }
70
71    /// Reads a declared input dependency.
72    pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
73    where
74        T: Clone + PartialEq + 'static,
75    {
76        let node = input.id();
77        self.require_declared(node)?;
78        self.graph
79            .input_values
80            .get(&node)
81            .and_then(|value| downcast_input::<T>(value.as_ref()))
82            .ok_or(DeriveError::MissingValue(node))
83    }
84
85    /// Reads a declared scalar derived dependency.
86    pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
87    where
88        T: Clone + PartialEq + 'static,
89    {
90        let node = derived.id();
91        self.require_declared(node)?;
92        self.graph
93            .derived_values
94            .get(&node)
95            .and_then(|value| downcast_input::<T>(value.as_ref()))
96            .ok_or(DeriveError::MissingValue(node))
97    }
98
99    /// Reads a declared map collection dependency.
100    pub fn map_collection<K, V>(
101        &self,
102        collection: CollectionNode<K, V>,
103    ) -> Result<&'graph BTreeMap<K, V>, DeriveError>
104    where
105        K: Clone + Ord + 'static,
106        V: Clone + PartialEq + 'static,
107    {
108        let node = collection.id();
109        self.require_declared(node)?;
110        self.graph
111            .validate_map_collection_read::<K, V>(node)
112            .map_err(|_| DeriveError::WrongCollectionType(node))?;
113        self.graph
114            .collection_values
115            .get(&node)
116            .and_then(|value| downcast_map::<K, V>(value.as_ref()))
117            .ok_or(DeriveError::MissingValue(node))
118    }
119
120    /// Reads a declared set collection dependency.
121    pub fn set_collection<K>(
122        &self,
123        collection: CollectionNode<K, ()>,
124    ) -> Result<&'graph BTreeSet<K>, DeriveError>
125    where
126        K: Clone + Ord + 'static,
127    {
128        let node = collection.id();
129        self.require_declared(node)?;
130        self.graph
131            .validate_set_collection_read::<K>(node)
132            .map_err(|_| DeriveError::WrongCollectionType(node))?;
133        self.graph
134            .collection_values
135            .get(&node)
136            .and_then(|value| downcast_set::<K>(value.as_ref()))
137            .ok_or(DeriveError::MissingValue(node))
138    }
139
140    fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
141        if self.declared_dependencies.contains(&node) {
142            Ok(())
143        } else {
144            Err(DeriveError::UndeclaredDependency(node))
145        }
146    }
147}
148
/// Type-erased storage for a collection node's value.
///
/// Implementors can copy themselves into a new box, produce an empty box of the same
/// concrete type, compare and diff against other stored collections, and expose
/// themselves as `&dyn Any` so callers can downcast to the concrete type.
pub(crate) trait StoredCollection: Any {
    /// Returns an owned, boxed copy of this collection.
    fn clone_box(&self) -> Box<dyn StoredCollection>;
    /// Returns a boxed, empty collection of the same concrete type.
    fn empty_box(&self) -> Box<dyn StoredCollection>;
    /// Reports whether `other` holds an equal collection.
    fn equals(&self, other: &dyn StoredCollection) -> bool;
    /// Computes the type-erased change from `self` to `next`.
    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff>;
    /// Views this value as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
}

/// Type-erased change between two stored collections.
pub(crate) trait StoredDiff: Any {
    /// Returns an owned, boxed copy of this diff.
    fn clone_box(&self) -> Box<dyn StoredDiff>;
    /// Views this value as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
}

// Trait objects cannot derive `Clone`, so cloning the box dispatches to `clone_box`
// through the vtable.
impl Clone for Box<dyn StoredCollection> {
    fn clone(&self) -> Self {
        (**self).clone_box()
    }
}

impl Clone for Box<dyn StoredDiff> {
    fn clone(&self) -> Self {
        (**self).clone_box()
    }
}
173
/// Concrete stored form of a keyed map collection.
#[derive(Clone)]
struct MapCollection<K, V> {
    /// Current contents of the map.
    value: BTreeMap<K, V>,
}

/// Concrete stored form of a set collection.
#[derive(Clone)]
struct SetCollection<K> {
    /// Current contents of the set.
    value: BTreeSet<K>,
}
183
184impl<K, V> StoredCollection for MapCollection<K, V>
185where
186    K: Clone + Ord + 'static,
187    V: Clone + PartialEq + 'static,
188{
189    fn clone_box(&self) -> Box<dyn StoredCollection> {
190        Box::new(self.clone())
191    }
192
193    fn empty_box(&self) -> Box<dyn StoredCollection> {
194        boxed_map(BTreeMap::<K, V>::new())
195    }
196
197    fn equals(&self, other: &dyn StoredCollection) -> bool {
198        other
199            .as_any()
200            .downcast_ref::<Self>()
201            .is_some_and(|other| self.value == other.value)
202    }
203
204    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff> {
205        let next = next
206            .as_any()
207            .downcast_ref::<Self>()
208            .expect("collection type stays stable");
209        Box::new(MapDiff::between(&self.value, &next.value))
210    }
211
212    fn as_any(&self) -> &dyn Any {
213        self
214    }
215}
216
217impl<K> StoredCollection for SetCollection<K>
218where
219    K: Clone + Ord + 'static,
220{
221    fn clone_box(&self) -> Box<dyn StoredCollection> {
222        Box::new(self.clone())
223    }
224
225    fn empty_box(&self) -> Box<dyn StoredCollection> {
226        boxed_set(BTreeSet::<K>::new())
227    }
228
229    fn equals(&self, other: &dyn StoredCollection) -> bool {
230        other
231            .as_any()
232            .downcast_ref::<Self>()
233            .is_some_and(|other| self.value == other.value)
234    }
235
236    fn diff(&self, next: &dyn StoredCollection) -> Box<dyn StoredDiff> {
237        let next = next
238            .as_any()
239            .downcast_ref::<Self>()
240            .expect("collection type stays stable");
241        Box::new(SetDiff::between(&self.value, &next.value))
242    }
243
244    fn as_any(&self) -> &dyn Any {
245        self
246    }
247}
248
249pub(crate) fn boxed_map<K, V>(value: BTreeMap<K, V>) -> Box<dyn StoredCollection>
250where
251    K: Clone + Ord + 'static,
252    V: Clone + PartialEq + 'static,
253{
254    Box::new(MapCollection { value })
255}
256
257pub(crate) fn boxed_set<K>(value: BTreeSet<K>) -> Box<dyn StoredCollection>
258where
259    K: Clone + Ord + 'static,
260{
261    Box::new(SetCollection { value })
262}
263
264pub(crate) fn downcast_map<K, V>(value: &dyn StoredCollection) -> Option<&BTreeMap<K, V>>
265where
266    K: Clone + Ord + 'static,
267    V: Clone + PartialEq + 'static,
268{
269    value
270        .as_any()
271        .downcast_ref::<MapCollection<K, V>>()
272        .map(|collection| &collection.value)
273}
274
275pub(crate) fn downcast_set<K>(value: &dyn StoredCollection) -> Option<&BTreeSet<K>>
276where
277    K: Clone + Ord + 'static,
278{
279    value
280        .as_any()
281        .downcast_ref::<SetCollection<K>>()
282        .map(|collection| &collection.value)
283}
284
285pub(crate) fn downcast_map_diff<K, V>(value: &dyn StoredDiff) -> Option<&MapDiff<K, V>>
286where
287    K: Clone + Ord + 'static,
288    V: Clone + PartialEq + 'static,
289{
290    value.as_any().downcast_ref::<MapDiff<K, V>>()
291}
292
293pub(crate) fn downcast_set_diff<K>(value: &dyn StoredDiff) -> Option<&SetDiff<K>>
294where
295    K: Clone + Ord + 'static,
296{
297    value.as_any().downcast_ref::<SetDiff<K>>()
298}