trellis_core/
collection.rs1pub(crate) use crate::collection_storage::{
2 StoredCollection, StoredDiff, boxed_map, boxed_set, downcast_map, downcast_map_diff,
3 downcast_set, downcast_set_diff,
4};
5use crate::input::downcast_input;
6use crate::{CollectionNode, DeriveError, DerivedNode, Graph, InputNode, NodeId};
7use core::marker::PhantomData;
8use std::collections::{BTreeMap, BTreeSet};
9use std::sync::Arc;
10
11type CollectionComputeResult = Result<Box<dyn StoredCollection>, DeriveError>;
12type ComputeFn<C> =
13 dyn for<'ctx> Fn(&CollectionContext<'ctx, C>) -> CollectionComputeResult + Send + Sync;
14
15pub(crate) struct MapCollectionShape<K, V>(PhantomData<fn() -> (K, V)>);
16pub(crate) struct SetCollectionShape<K>(PhantomData<fn() -> K>);
17pub(crate) struct CollectionSpec<C> {
18 compute: Arc<ComputeFn<C>>,
19}
20
21impl<C> Clone for CollectionSpec<C> {
22 fn clone(&self) -> Self {
23 Self {
24 compute: Arc::clone(&self.compute),
25 }
26 }
27}
28impl<C> CollectionSpec<C> {
29 pub(crate) fn map<K, V, F>(derive: F) -> Self
30 where
31 K: Clone + Ord + Send + Sync + 'static,
32 V: Clone + PartialEq + Send + Sync + 'static,
33 F: for<'ctx> Fn(&CollectionContext<'ctx, C>) -> Result<BTreeMap<K, V>, DeriveError>
34 + Send
35 + Sync
36 + 'static,
37 {
38 Self {
39 compute: Arc::new(move |ctx| derive(ctx).map(boxed_map)),
40 }
41 }
42
43 pub(crate) fn set<K, F>(derive: F) -> Self
44 where
45 K: Clone + Ord + Send + Sync + 'static,
46 F: for<'ctx> Fn(&CollectionContext<'ctx, C>) -> Result<BTreeSet<K>, DeriveError>
47 + Send
48 + Sync
49 + 'static,
50 {
51 Self {
52 compute: Arc::new(move |ctx| derive(ctx).map(boxed_set)),
53 }
54 }
55
56 pub(crate) fn compute(
57 &self,
58 ctx: &CollectionContext<'_, C>,
59 ) -> Result<Box<dyn StoredCollection>, DeriveError> {
60 (self.compute)(ctx)
61 }
62}
63pub struct CollectionContext<'graph, C = ()> {
65 graph: &'graph Graph<C>,
66 declared_dependencies: &'graph [NodeId],
67}
68
69impl<'graph, C> CollectionContext<'graph, C> {
70 pub(crate) fn new(graph: &'graph Graph<C>, declared_dependencies: &'graph [NodeId]) -> Self {
71 Self {
72 graph,
73 declared_dependencies,
74 }
75 }
76
77 pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
79 where
80 T: Clone + PartialEq + Send + Sync + 'static,
81 {
82 let node = input.id();
83 self.require_declared(node)?;
84 self.graph
85 .input_values
86 .get(&node)
87 .and_then(|value| downcast_input::<T>(value.as_ref()))
88 .ok_or(DeriveError::MissingValue(node))
89 }
90
91 pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
93 where
94 T: Clone + PartialEq + Send + Sync + 'static,
95 {
96 let node = derived.id();
97 self.require_declared(node)?;
98 self.graph
99 .derived_values
100 .get(&node)
101 .and_then(|value| downcast_input::<T>(value.as_ref()))
102 .ok_or(DeriveError::MissingValue(node))
103 }
104
105 pub fn map_collection<K, V>(
107 &self,
108 collection: CollectionNode<K, V>,
109 ) -> Result<&'graph BTreeMap<K, V>, DeriveError>
110 where
111 K: Clone + Ord + Send + Sync + 'static,
112 V: Clone + PartialEq + Send + Sync + 'static,
113 {
114 let node = collection.id();
115 self.require_declared(node)?;
116 self.graph
117 .validate_map_collection_read::<K, V>(node)
118 .map_err(|_| DeriveError::WrongCollectionType(node))?;
119 self.graph
120 .collection_values
121 .get(&node)
122 .and_then(|value| downcast_map::<K, V>(value.as_ref()))
123 .ok_or(DeriveError::MissingValue(node))
124 }
125
126 pub fn set_collection<K>(
128 &self,
129 collection: CollectionNode<K, ()>,
130 ) -> Result<&'graph BTreeSet<K>, DeriveError>
131 where
132 K: Clone + Ord + Send + Sync + 'static,
133 {
134 let node = collection.id();
135 self.require_declared(node)?;
136 self.graph
137 .validate_set_collection_read::<K>(node)
138 .map_err(|_| DeriveError::WrongCollectionType(node))?;
139 self.graph
140 .collection_values
141 .get(&node)
142 .and_then(|value| downcast_set::<K>(value.as_ref()))
143 .ok_or(DeriveError::MissingValue(node))
144 }
145
146 fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
147 if self.declared_dependencies.contains(&node) {
148 Ok(())
149 } else {
150 Err(DeriveError::UndeclaredDependency(node))
151 }
152 }
153}